pafpn.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmdet.registry import MODELS
from .fpn import FPN


@MODELS.register_module()
class PAFPN(FPN):
    """Path Aggregation Network for Instance Segmentation.

    This is an implementation of the `PAFPN in Path Aggregation Network
    <https://arxiv.org/abs/1803.01534>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Default: False.
            If True, it is equivalent to `add_extra_convs='on_input'`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed:

            - 'on_input': Last feat map of neck inputs (i.e. backbone
              feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
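
    The example below is a minimal shape-check sketch; the channel counts
    and spatial sizes are illustrative assumptions, not values taken from
    any particular backbone.

    Example:
        >>> import torch
        >>> in_channels = [256, 512, 1024]
        >>> scales = [16, 8, 4]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> neck = PAFPN(in_channels, out_channels=64, num_outs=5)
        >>> outputs = neck(inputs)
        >>> for out in outputs:
        ...     print(out.shape)
        torch.Size([1, 64, 16, 16])
        torch.Size([1, 64, 8, 8])
        torch.Size([1, 64, 4, 4])
        torch.Size([1, 64, 2, 2])
        torch.Size([1, 64, 1, 1])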
  37. """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(PAFPN, self).__init__(
            in_channels,
            out_channels,
            num_outs,
            start_level,
            end_level,
            add_extra_convs,
            relu_before_extra_convs,
            no_norm_on_lateral,
            conv_cfg,
            norm_cfg,
            act_cfg,
            init_cfg=init_cfg)
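        # The parent FPN constructor builds self.lateral_convs (1x1 convs on
        # the backbone features) and self.fpn_convs (3x3 convs on the merged
        # maps); both are reused in forward() below.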

        # add extra bottom-up pathway
        self.downsample_convs = nn.ModuleList()
        self.pafpn_convs = nn.ModuleList()
        for i in range(self.start_level + 1, self.backbone_end_level):
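            # 3x3 stride-2 conv that downsamples level i - 1 so it can be
            # fused into level i along the bottom-up path.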
            d_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
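            # 3x3 conv that smooths the fused feature map at each coarser
            # level.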
            pafpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.downsample_convs.append(d_conv)
            self.pafpn_convs.append(pafpn_conv)

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
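        # Each lateral is now an out_channels-wide view of one backbone
        # level, so all subsequent fusion happens at a common channel width.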

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # build outputs
        # part 1: from original levels
        inter_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels)
        ]

        # part 2: add bottom-up path
        for i in range(0, used_backbone_levels - 1):
            inter_outs[i + 1] = inter_outs[i + 1] + \
                self.downsample_convs[i](inter_outs[i])
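
        # The finest level (index 0) is reused as-is; only the coarser
        # levels receive an extra 3x3 conv after bottom-up fusion.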
        outs = []
        outs.append(inter_outs[0])
        outs.extend([
            self.pafpn_convs[i - 1](inter_outs[i])
            for i in range(1, used_backbone_levels)
        ])

        # part 3: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    orig = inputs[self.backbone_end_level - 1]
                    outs.append(self.fpn_convs[used_backbone_levels](orig))
                elif self.add_extra_convs == 'on_lateral':
                    outs.append(self.fpn_convs[used_backbone_levels](
                        laterals[-1]))
                elif self.add_extra_convs == 'on_output':
                    outs.append(
                        self.fpn_convs[used_backbone_levels](outs[-1]))
                else:
                    raise NotImplementedError
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)