bfp.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch.nn.functional as F
  4. from mmcv.cnn import ConvModule
  5. from mmcv.cnn.bricks import NonLocal2d
  6. from mmengine.model import BaseModule
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import OptConfigType, OptMultiConfig
  10. @MODELS.register_module()
  11. class BFP(BaseModule):
  12. """BFP (Balanced Feature Pyramids)
  13. BFP takes multi-level features as inputs and gather them into a single one,
  14. then refine the gathered feature and scatter the refined results to
  15. multi-level features. This module is used in Libra R-CNN (CVPR 2019), see
  16. the paper `Libra R-CNN: Towards Balanced Learning for Object Detection
  17. <https://arxiv.org/abs/1904.02701>`_ for details.
  18. Args:
  19. in_channels (int): Number of input channels (feature maps of all levels
  20. should have the same channels).
  21. num_levels (int): Number of input feature levels.
  22. refine_level (int): Index of integration and refine level of BSF in
  23. multi-level features from bottom to top.
  24. refine_type (str): Type of the refine op, currently support
  25. [None, 'conv', 'non_local'].
  26. conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
  27. convolution layers.
  28. norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
  29. normalization layers.
  30. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
  31. dict], optional): Initialization config dict.
  32. """
  33. def __init__(
  34. self,
  35. in_channels: int,
  36. num_levels: int,
  37. refine_level: int = 2,
  38. refine_type: str = None,
  39. conv_cfg: OptConfigType = None,
  40. norm_cfg: OptConfigType = None,
  41. init_cfg: OptMultiConfig = dict(
  42. type='Xavier', layer='Conv2d', distribution='uniform')
  43. ) -> None:
  44. super().__init__(init_cfg=init_cfg)
  45. assert refine_type in [None, 'conv', 'non_local']
  46. self.in_channels = in_channels
  47. self.num_levels = num_levels
  48. self.conv_cfg = conv_cfg
  49. self.norm_cfg = norm_cfg
  50. self.refine_level = refine_level
  51. self.refine_type = refine_type
  52. assert 0 <= self.refine_level < self.num_levels
  53. if self.refine_type == 'conv':
  54. self.refine = ConvModule(
  55. self.in_channels,
  56. self.in_channels,
  57. 3,
  58. padding=1,
  59. conv_cfg=self.conv_cfg,
  60. norm_cfg=self.norm_cfg)
  61. elif self.refine_type == 'non_local':
  62. self.refine = NonLocal2d(
  63. self.in_channels,
  64. reduction=1,
  65. use_scale=False,
  66. conv_cfg=self.conv_cfg,
  67. norm_cfg=self.norm_cfg)
  68. def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
  69. """Forward function."""
  70. assert len(inputs) == self.num_levels
  71. # step 1: gather multi-level features by resize and average
  72. feats = []
  73. gather_size = inputs[self.refine_level].size()[2:]
  74. for i in range(self.num_levels):
  75. if i < self.refine_level:
  76. gathered = F.adaptive_max_pool2d(
  77. inputs[i], output_size=gather_size)
  78. else:
  79. gathered = F.interpolate(
  80. inputs[i], size=gather_size, mode='nearest')
  81. feats.append(gathered)
  82. bsf = sum(feats) / len(feats)
  83. # step 2: refine gathered features
  84. if self.refine_type is not None:
  85. bsf = self.refine(bsf)
  86. # step 3: scatter refined features to multi-levels by a residual path
  87. outs = []
  88. for i in range(self.num_levels):
  89. out_size = inputs[i].size()[2:]
  90. if i < self.refine_level:
  91. residual = F.interpolate(bsf, size=out_size, mode='nearest')
  92. else:
  93. residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
  94. outs.append(residual + inputs[i])
  95. return tuple(outs)