fpn_carafe.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.cnn import ConvModule, build_upsample_layer
  4. from mmcv.ops.carafe import CARAFEPack
  5. from mmengine.model import BaseModule, ModuleList, xavier_init
  6. from mmdet.registry import MODELS
  7. @MODELS.register_module()
  8. class FPN_CARAFE(BaseModule):
  9. """FPN_CARAFE is a more flexible implementation of FPN. It allows more
  10. choice for upsample methods during the top-down pathway.
  11. It can reproduce the performance of ICCV 2019 paper
  12. CARAFE: Content-Aware ReAssembly of FEatures
  13. Please refer to https://arxiv.org/abs/1905.02188 for more details.
  14. Args:
  15. in_channels (list[int]): Number of channels for each input feature map.
  16. out_channels (int): Output channels of feature pyramids.
  17. num_outs (int): Number of output stages.
  18. start_level (int): Start level of feature pyramids.
  19. (Default: 0)
  20. end_level (int): End level of feature pyramids.
  21. (Default: -1 indicates the last level).
  22. norm_cfg (dict): Dictionary to construct and config norm layer.
  23. activate (str): Type of activation function in ConvModule
  24. (Default: None indicates w/o activation).
  25. order (dict): Order of components in ConvModule.
  26. upsample (str): Type of upsample layer.
  27. upsample_cfg (dict): Dictionary to construct and config upsample layer.
  28. init_cfg (dict or list[dict], optional): Initialization config dict.
  29. Default: None
  30. """
  31. def __init__(self,
  32. in_channels,
  33. out_channels,
  34. num_outs,
  35. start_level=0,
  36. end_level=-1,
  37. norm_cfg=None,
  38. act_cfg=None,
  39. order=('conv', 'norm', 'act'),
  40. upsample_cfg=dict(
  41. type='carafe',
  42. up_kernel=5,
  43. up_group=1,
  44. encoder_kernel=3,
  45. encoder_dilation=1),
  46. init_cfg=None):
  47. assert init_cfg is None, 'To prevent abnormal initialization ' \
  48. 'behavior, init_cfg is not allowed to be set'
  49. super(FPN_CARAFE, self).__init__(init_cfg)
  50. assert isinstance(in_channels, list)
  51. self.in_channels = in_channels
  52. self.out_channels = out_channels
  53. self.num_ins = len(in_channels)
  54. self.num_outs = num_outs
  55. self.norm_cfg = norm_cfg
  56. self.act_cfg = act_cfg
  57. self.with_bias = norm_cfg is None
  58. self.upsample_cfg = upsample_cfg.copy()
  59. self.upsample = self.upsample_cfg.get('type')
  60. self.relu = nn.ReLU(inplace=False)
  61. self.order = order
  62. assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]
  63. assert self.upsample in [
  64. 'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
  65. ]
  66. if self.upsample in ['deconv', 'pixel_shuffle']:
  67. assert hasattr(
  68. self.upsample_cfg,
  69. 'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0
  70. self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')
  71. if end_level == -1 or end_level == self.num_ins - 1:
  72. self.backbone_end_level = self.num_ins
  73. assert num_outs >= self.num_ins - start_level
  74. else:
  75. # if end_level is not the last level, no extra level is allowed
  76. self.backbone_end_level = end_level + 1
  77. assert end_level < self.num_ins
  78. assert num_outs == end_level - start_level + 1
  79. self.start_level = start_level
  80. self.end_level = end_level
  81. self.lateral_convs = ModuleList()
  82. self.fpn_convs = ModuleList()
  83. self.upsample_modules = ModuleList()
  84. for i in range(self.start_level, self.backbone_end_level):
  85. l_conv = ConvModule(
  86. in_channels[i],
  87. out_channels,
  88. 1,
  89. norm_cfg=norm_cfg,
  90. bias=self.with_bias,
  91. act_cfg=act_cfg,
  92. inplace=False,
  93. order=self.order)
  94. fpn_conv = ConvModule(
  95. out_channels,
  96. out_channels,
  97. 3,
  98. padding=1,
  99. norm_cfg=self.norm_cfg,
  100. bias=self.with_bias,
  101. act_cfg=act_cfg,
  102. inplace=False,
  103. order=self.order)
  104. if i != self.backbone_end_level - 1:
  105. upsample_cfg_ = self.upsample_cfg.copy()
  106. if self.upsample == 'deconv':
  107. upsample_cfg_.update(
  108. in_channels=out_channels,
  109. out_channels=out_channels,
  110. kernel_size=self.upsample_kernel,
  111. stride=2,
  112. padding=(self.upsample_kernel - 1) // 2,
  113. output_padding=(self.upsample_kernel - 1) // 2)
  114. elif self.upsample == 'pixel_shuffle':
  115. upsample_cfg_.update(
  116. in_channels=out_channels,
  117. out_channels=out_channels,
  118. scale_factor=2,
  119. upsample_kernel=self.upsample_kernel)
  120. elif self.upsample == 'carafe':
  121. upsample_cfg_.update(channels=out_channels, scale_factor=2)
  122. else:
  123. # suppress warnings
  124. align_corners = (None
  125. if self.upsample == 'nearest' else False)
  126. upsample_cfg_.update(
  127. scale_factor=2,
  128. mode=self.upsample,
  129. align_corners=align_corners)
  130. upsample_module = build_upsample_layer(upsample_cfg_)
  131. self.upsample_modules.append(upsample_module)
  132. self.lateral_convs.append(l_conv)
  133. self.fpn_convs.append(fpn_conv)
  134. # add extra conv layers (e.g., RetinaNet)
  135. extra_out_levels = (
  136. num_outs - self.backbone_end_level + self.start_level)
  137. if extra_out_levels >= 1:
  138. for i in range(extra_out_levels):
  139. in_channels = (
  140. self.in_channels[self.backbone_end_level -
  141. 1] if i == 0 else out_channels)
  142. extra_l_conv = ConvModule(
  143. in_channels,
  144. out_channels,
  145. 3,
  146. stride=2,
  147. padding=1,
  148. norm_cfg=norm_cfg,
  149. bias=self.with_bias,
  150. act_cfg=act_cfg,
  151. inplace=False,
  152. order=self.order)
  153. if self.upsample == 'deconv':
  154. upsampler_cfg_ = dict(
  155. in_channels=out_channels,
  156. out_channels=out_channels,
  157. kernel_size=self.upsample_kernel,
  158. stride=2,
  159. padding=(self.upsample_kernel - 1) // 2,
  160. output_padding=(self.upsample_kernel - 1) // 2)
  161. elif self.upsample == 'pixel_shuffle':
  162. upsampler_cfg_ = dict(
  163. in_channels=out_channels,
  164. out_channels=out_channels,
  165. scale_factor=2,
  166. upsample_kernel=self.upsample_kernel)
  167. elif self.upsample == 'carafe':
  168. upsampler_cfg_ = dict(
  169. channels=out_channels,
  170. scale_factor=2,
  171. **self.upsample_cfg)
  172. else:
  173. # suppress warnings
  174. align_corners = (None
  175. if self.upsample == 'nearest' else False)
  176. upsampler_cfg_ = dict(
  177. scale_factor=2,
  178. mode=self.upsample,
  179. align_corners=align_corners)
  180. upsampler_cfg_['type'] = self.upsample
  181. upsample_module = build_upsample_layer(upsampler_cfg_)
  182. extra_fpn_conv = ConvModule(
  183. out_channels,
  184. out_channels,
  185. 3,
  186. padding=1,
  187. norm_cfg=self.norm_cfg,
  188. bias=self.with_bias,
  189. act_cfg=act_cfg,
  190. inplace=False,
  191. order=self.order)
  192. self.upsample_modules.append(upsample_module)
  193. self.fpn_convs.append(extra_fpn_conv)
  194. self.lateral_convs.append(extra_l_conv)
  195. # default init_weights for conv(msra) and norm in ConvModule
  196. def init_weights(self):
  197. """Initialize the weights of module."""
  198. super(FPN_CARAFE, self).init_weights()
  199. for m in self.modules():
  200. if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
  201. xavier_init(m, distribution='uniform')
  202. for m in self.modules():
  203. if isinstance(m, CARAFEPack):
  204. m.init_weights()
  205. def slice_as(self, src, dst):
  206. """Slice ``src`` as ``dst``
  207. Note:
  208. ``src`` should have the same or larger size than ``dst``.
  209. Args:
  210. src (torch.Tensor): Tensors to be sliced.
  211. dst (torch.Tensor): ``src`` will be sliced to have the same
  212. size as ``dst``.
  213. Returns:
  214. torch.Tensor: Sliced tensor.
  215. """
  216. assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
  217. if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
  218. return src
  219. else:
  220. return src[:, :, :dst.size(2), :dst.size(3)]
  221. def tensor_add(self, a, b):
  222. """Add tensors ``a`` and ``b`` that might have different sizes."""
  223. if a.size() == b.size():
  224. c = a + b
  225. else:
  226. c = a + self.slice_as(b, a)
  227. return c
  228. def forward(self, inputs):
  229. """Forward function."""
  230. assert len(inputs) == len(self.in_channels)
  231. # build laterals
  232. laterals = []
  233. for i, lateral_conv in enumerate(self.lateral_convs):
  234. if i <= self.backbone_end_level - self.start_level:
  235. input = inputs[min(i + self.start_level, len(inputs) - 1)]
  236. else:
  237. input = laterals[-1]
  238. lateral = lateral_conv(input)
  239. laterals.append(lateral)
  240. # build top-down path
  241. for i in range(len(laterals) - 1, 0, -1):
  242. if self.upsample is not None:
  243. upsample_feat = self.upsample_modules[i - 1](laterals[i])
  244. else:
  245. upsample_feat = laterals[i]
  246. laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)
  247. # build outputs
  248. num_conv_outs = len(self.fpn_convs)
  249. outs = []
  250. for i in range(num_conv_outs):
  251. out = self.fpn_convs[i](laterals[i])
  252. outs.append(out)
  253. return tuple(outs)