# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS


class Transition(BaseModule):
    """Base class for transition.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        """Subclasses implement the actual transition."""
        raise NotImplementedError


class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation, then refines the feature with
    a convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool): Whether to align corners when interpolating.
            Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x
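

# A minimal sketch of what UpInterpolationConv does (illustrative only; the
# tensor shape and channel counts below are arbitrary assumptions, not values
# from this file):
#
#     import torch
#     up = UpInterpolationConv(256, 256, scale_factor=2)
#     out = up(torch.rand(1, 256, 16, 16))
#     # out.shape == torch.Size([1, 256, 32, 32]): spatial size doubled by
#     # nearest interpolation, then refined by the 3x3 conv.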


class LastConv(Transition):
    """A transition used for refining the output of the last stage.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of inputs of the FPN features.
        kernel_size (int): Kernel size for the conv. Default: 3.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        self.conv_out = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, inputs):
        assert len(inputs) == self.num_inputs
        return self.conv_out(inputs[-1])
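

# LastConv consumes the per-level feature history accumulated across stacked
# pathways and refines only the most recent entry. A minimal sketch
# (illustrative; shapes and channel counts are arbitrary assumptions):
#
#     import torch
#     last = LastConv(256, 256, num_inputs=3)
#     history = [torch.rand(1, 256, 8, 8) for _ in range(3)]
#     out = last(history)  # 3x3 conv applied to history[-1] only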


@MODELS.register_module()
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.

    This implementation only gives the basic structure stated in the paper.
    Users can implement different types of transitions to fully explore the
    potential power of the structure of FPG.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down).
        inter_channels (int | list[int], optional): Number of inter channels.
            Defaults to ``out_channels`` at every level if None.
        same_up_trans (dict): Transition that goes up at the same stage.
        same_down_trans (dict): Transition that goes down at the same stage.
        across_lateral_trans (dict): Across-pathway same-stage transition.
        across_down_trans (dict): Across-pathway top-down connection
            (consumes the next coarser level).
        across_up_trans (dict): Across-pathway bottom-up connection
            (consumes the next finer level).
        across_skip_trans (dict): Across-pathway skip connection.
        output_trans (dict): Transition that transforms the output of the
            last stage.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool): Whether to add conv layers on top of the
            original feature maps to reach ``num_outs`` levels. If False,
            stride-2 max-pooling is used instead. Default: False.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        skip_inds (list[tuple[int]], optional): For each output level, the
            indices of the stacked pathways to skip; a skipped stage carries
            its previous feature forward unchanged. Must be specified if
            ``across_skip_trans`` is not None. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
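
    Example:
        >>> # An illustrative sketch: the channel counts mirror a typical
        >>> # ResNet-50 backbone; all other values are arbitrary choices.
        >>> # The across-down connection reads the next coarser level, so an
        >>> # 'interpolation_conv' is used to keep spatial sizes aligned.
        >>> # across_skip_trans is disabled because its default 'identity'
        >>> # type is not registered in ``transition_types``.
        >>> import torch
        >>> in_channels = [256, 512, 1024, 2048]
        >>> scales = [64, 32, 16, 8]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPG(in_channels,
        ...            out_channels=256,
        ...            num_outs=5,
        ...            stack_times=2,
        ...            paths=['bu', 'td'],
        ...            across_down_trans=dict(
        ...                type='interpolation_conv',
        ...                kernel_size=3,
        ...                scale_factor=2),
        ...            across_skip_trans=None,
        ...            skip_inds=[(), (), (), (), ()])
        >>> outputs = self.forward(inputs)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 256, 64, 64)
        (1, 256, 32, 32)
        (1, 256, 16, 16)
        (1, 256, 8, 8)
        (1, 256, 4, 4)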
  121. """

    # Maps the ``type`` key of a transition config to its implementing class.
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')
        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans
        self.with_bias = norm_cfg is None
        # skip_inds must be specified if across_skip_trans is not None
        if self.across_skip_trans is not None:
            assert skip_inds is not None
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times
        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        # build lateral 1x1 convs to reduce channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                # a 1x1 max-pool with stride 2 is plain stride-2 subsampling
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))
        self.fpn_transitions = nn.ModuleList()  # stack times
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # num of feature levels
            for i in range(self.num_outs):
                # same, across_lateral, across_down, across_up
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue
                # build same-stage up trans (used in bottom-up paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # build same-stage down trans (used in top-down paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # build across lateral trans
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # build across down trans
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # build across up trans
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # build across skip trans
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)
        self.output_transition = nn.ModuleList()  # output levels
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)
        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Build a transition module from its config dict."""
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)
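
    # A sketch of how transition configs resolve (illustrative values):
    #     self.build_trans(dict(type='conv', kernel_size=1), 256, 256)
    #         -> ConvModule(256, 256, kernel_size=1)
    #     self.build_trans(dict(type='last_conv', kernel_size=3), 256, 256,
    #                      num_inputs=10)
    #         -> LastConv(256, 256, kernel_size=3, num_inputs=10)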

    def fuse(self, fuse_dict):
        """Sum all non-None tensors in ``fuse_dict``."""
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out
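
    # For instance (illustrative), with t1 and t2 tensors of equal shape,
    # self.fuse(dict(same=t1, lateral=None, across_down=t2)) returns t1 + t2,
    # and the result is None if every entry is None.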

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build all levels from original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                if i in self.skip_inds[j]:
                    next_outs.append(outs[-1][j])
                    continue
                # feature level
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # get transitions
                if direction == 'td':
                    same_trans = self.fpn_transitions[i][lvl]['same_down']
                else:
                    same_trans = self.fpn_transitions[i][lvl]['same_up']
                across_lateral_trans = self.fpn_transitions[i][lvl][
                    'across_lateral']
                across_down_trans = self.fpn_transitions[i][lvl]['across_down']
                across_up_trans = self.fpn_transitions[i][lvl]['across_up']
                across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
                # init output
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same downsample/upsample
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across lateral
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # across up: consumes the next finer level
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # across down: consumes the next coarser level
                if lvl < self.num_outs - 1 and across_down_trans is not None:
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # skip connection from the original pyramid
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # refine each level with the output transition
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = []
            for s in range(len(outs)):
                lvl_out_list.append(outs[s][i])
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)

        return final_outs
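

# A minimal sketch of building this neck through the registry (illustrative;
# 'FPG' is registered above, but the argument values below are arbitrary
# assumptions rather than a shipped configuration):
#
#     from mmdet.registry import MODELS
#     neck = MODELS.build(
#         dict(
#             type='FPG',
#             in_channels=[256, 512, 1024, 2048],
#             out_channels=256,
#             num_outs=5,
#             stack_times=2,
#             paths=['bu', 'td'],
#             across_down_trans=dict(
#                 type='interpolation_conv', kernel_size=3, scale_factor=2),
#             across_skip_trans=None,
#             skip_inds=[(0, ), (), (), (), ()]))
#
# With skip_inds=[(0, ), ...], output level 0 skips the first stacked pathway
# and simply carries its previous feature forward.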