123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import NonLocal2d
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig
@MODELS.register_module()
class BFP(BaseModule):
    """BFP (Balanced Feature Pyramids).

    BFP takes multi-level features as inputs, gathers them into a single
    one, refines the gathered feature and scatters the refined result back
    to multi-level features. This module is used in Libra R-CNN
    (CVPR 2019), see the paper `Libra R-CNN: Towards Balanced Learning for
    Object Detection <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of integration and refine level of BSF in
            multi-level features from bottom to top. Must satisfy
            ``0 <= refine_level < num_levels``. Defaults to 2.
        refine_type (str, optional): Type of the refine op, currently
            supports [None, 'conv', 'non_local']. ``None`` skips the refine
            step. Defaults to None.
        conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            convolution layers. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            normalization layers. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to
            ``dict(type='Xavier', layer='Conv2d', distribution='uniform')``.
    """

    def __init__(
        self,
        in_channels: int,
        num_levels: int,
        refine_level: int = 2,
        refine_type: Optional[str] = None,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        # NOTE(review): ``init_cfg`` defaults to a dict literal shared by
        # every call; this relies on BaseModule not mutating it in place
        # (presumably it copies it) -- confirm.
        super().__init__(init_cfg=init_cfg)
        assert refine_type in [None, 'conv', 'non_local']

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels

        # ``self.refine`` only exists when a refine op is requested; forward
        # checks ``refine_type`` before calling it.
        if self.refine_type == 'conv':
            # 3x3 conv with padding 1, in_channels -> in_channels, so the
            # refined BSF can still be added to every input level.
            self.refine = ConvModule(
                self.in_channels,
                self.in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        elif self.refine_type == 'non_local':
            self.refine = NonLocal2d(
                self.in_channels,
                reduction=1,
                use_scale=False,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): Multi-level features ordered from bottom
                to top. Its length must equal ``num_levels``.

        Returns:
            tuple[Tensor]: One output per input level. Each is the
            corresponding input plus the (optionally refined) balanced
            semantic feature resized to that level's spatial size.
        """
        assert len(inputs) == self.num_levels

        # step 1: gather multi-level features by resize and average
        bsf = self._gather_levels(inputs)

        # step 2: refine gathered features
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # step 3: scatter refined features to multi-levels by a residual path
        return self._scatter_levels(bsf, inputs)

    def _gather_levels(self, inputs: Tuple[Tensor, ...]) -> Tensor:
        """Resize every level to the refine level's size and average them."""
        # ``size()[2:]`` is taken as the spatial size, i.e. inputs are
        # assumed to be (N, C, H, W) -- the 2d pooling below requires it.
        gather_size = inputs[self.refine_level].size()[2:]
        feats = []
        for i, feat in enumerate(inputs):
            if i < self.refine_level:
                # Levels below the refine level (presumably larger, given
                # the bottom-to-top order) are max-pooled to gather_size.
                gathered = F.adaptive_max_pool2d(
                    feat, output_size=gather_size)
            else:
                # The refine level itself and the levels above it are
                # nearest-interpolated to gather_size.
                gathered = F.interpolate(
                    feat, size=gather_size, mode='nearest')
            feats.append(gathered)
        return sum(feats) / len(feats)

    def _scatter_levels(self, bsf: Tensor,
                        inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        """Resize ``bsf`` back to each level and add it as a residual."""
        outs = []
        for i, feat in enumerate(inputs):
            out_size = feat.size()[2:]
            # Mirror of the gather step: interpolate toward the levels
            # below the refine level, max-pool toward the rest.
            if i < self.refine_level:
                residual = F.interpolate(bsf, size=out_size, mode='nearest')
            else:
                residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
            outs.append(residual + feat)
        return tuple(outs)
|