detectors_resnet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. import torch.utils.checkpoint as cp
  4. from mmcv.cnn import build_conv_layer, build_norm_layer
  5. from mmengine.logging import MMLogger
  6. from mmengine.model import Sequential, constant_init, kaiming_init
  7. from mmengine.runner.checkpoint import load_checkpoint
  8. from torch.nn.modules.batchnorm import _BatchNorm
  9. from mmdet.registry import MODELS
  10. from .resnet import BasicBlock
  11. from .resnet import Bottleneck as _Bottleneck
  12. from .resnet import ResNet
  13. class Bottleneck(_Bottleneck):
  14. r"""Bottleneck for the ResNet backbone in `DetectoRS
  15. <https://arxiv.org/pdf/2006.02334.pdf>`_.
  16. This bottleneck allows the users to specify whether to use
  17. SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).
  18. Args:
  19. inplanes (int): The number of input channels.
  20. planes (int): The number of output channels before expansion.
  21. rfp_inplanes (int, optional): The number of channels from RFP.
  22. Default: None. If specified, an additional conv layer will be
  23. added for ``rfp_feat``. Otherwise, the structure is the same as
  24. base class.
  25. sac (dict, optional): Dictionary to construct SAC. Default: None.
  26. init_cfg (dict or list[dict], optional): Initialization config dict.
  27. Default: None
  28. """
  29. expansion = 4
  30. def __init__(self,
  31. inplanes,
  32. planes,
  33. rfp_inplanes=None,
  34. sac=None,
  35. init_cfg=None,
  36. **kwargs):
  37. super(Bottleneck, self).__init__(
  38. inplanes, planes, init_cfg=init_cfg, **kwargs)
  39. assert sac is None or isinstance(sac, dict)
  40. self.sac = sac
  41. self.with_sac = sac is not None
  42. if self.with_sac:
  43. self.conv2 = build_conv_layer(
  44. self.sac,
  45. planes,
  46. planes,
  47. kernel_size=3,
  48. stride=self.conv2_stride,
  49. padding=self.dilation,
  50. dilation=self.dilation,
  51. bias=False)
  52. self.rfp_inplanes = rfp_inplanes
  53. if self.rfp_inplanes:
  54. self.rfp_conv = build_conv_layer(
  55. None,
  56. self.rfp_inplanes,
  57. planes * self.expansion,
  58. 1,
  59. stride=1,
  60. bias=True)
  61. if init_cfg is None:
  62. self.init_cfg = dict(
  63. type='Constant', val=0, override=dict(name='rfp_conv'))
  64. def rfp_forward(self, x, rfp_feat):
  65. """The forward function that also takes the RFP features as input."""
  66. def _inner_forward(x):
  67. identity = x
  68. out = self.conv1(x)
  69. out = self.norm1(out)
  70. out = self.relu(out)
  71. if self.with_plugins:
  72. out = self.forward_plugin(out, self.after_conv1_plugin_names)
  73. out = self.conv2(out)
  74. out = self.norm2(out)
  75. out = self.relu(out)
  76. if self.with_plugins:
  77. out = self.forward_plugin(out, self.after_conv2_plugin_names)
  78. out = self.conv3(out)
  79. out = self.norm3(out)
  80. if self.with_plugins:
  81. out = self.forward_plugin(out, self.after_conv3_plugin_names)
  82. if self.downsample is not None:
  83. identity = self.downsample(x)
  84. out += identity
  85. return out
  86. if self.with_cp and x.requires_grad:
  87. out = cp.checkpoint(_inner_forward, x)
  88. else:
  89. out = _inner_forward(x)
  90. if self.rfp_inplanes:
  91. rfp_feat = self.rfp_conv(rfp_feat)
  92. out = out + rfp_feat
  93. out = self.relu(out)
  94. return out
  95. class ResLayer(Sequential):
  96. """ResLayer to build ResNet style backbone for RPF in detectoRS.
  97. The difference between this module and base class is that we pass
  98. ``rfp_inplanes`` to the first block.
  99. Args:
  100. block (nn.Module): block used to build ResLayer.
  101. inplanes (int): inplanes of block.
  102. planes (int): planes of block.
  103. num_blocks (int): number of blocks.
  104. stride (int): stride of the first block. Default: 1
  105. avg_down (bool): Use AvgPool instead of stride conv when
  106. downsampling in the bottleneck. Default: False
  107. conv_cfg (dict): dictionary to construct and config conv layer.
  108. Default: None
  109. norm_cfg (dict): dictionary to construct and config norm layer.
  110. Default: dict(type='BN')
  111. downsample_first (bool): Downsample at the first block or last block.
  112. False for Hourglass, True for ResNet. Default: True
  113. rfp_inplanes (int, optional): The number of channels from RFP.
  114. Default: None. If specified, an additional conv layer will be
  115. added for ``rfp_feat``. Otherwise, the structure is the same as
  116. base class.
  117. """
  118. def __init__(self,
  119. block,
  120. inplanes,
  121. planes,
  122. num_blocks,
  123. stride=1,
  124. avg_down=False,
  125. conv_cfg=None,
  126. norm_cfg=dict(type='BN'),
  127. downsample_first=True,
  128. rfp_inplanes=None,
  129. **kwargs):
  130. self.block = block
  131. assert downsample_first, f'downsample_first={downsample_first} is ' \
  132. 'not supported in DetectoRS'
  133. downsample = None
  134. if stride != 1 or inplanes != planes * block.expansion:
  135. downsample = []
  136. conv_stride = stride
  137. if avg_down and stride != 1:
  138. conv_stride = 1
  139. downsample.append(
  140. nn.AvgPool2d(
  141. kernel_size=stride,
  142. stride=stride,
  143. ceil_mode=True,
  144. count_include_pad=False))
  145. downsample.extend([
  146. build_conv_layer(
  147. conv_cfg,
  148. inplanes,
  149. planes * block.expansion,
  150. kernel_size=1,
  151. stride=conv_stride,
  152. bias=False),
  153. build_norm_layer(norm_cfg, planes * block.expansion)[1]
  154. ])
  155. downsample = nn.Sequential(*downsample)
  156. layers = []
  157. layers.append(
  158. block(
  159. inplanes=inplanes,
  160. planes=planes,
  161. stride=stride,
  162. downsample=downsample,
  163. conv_cfg=conv_cfg,
  164. norm_cfg=norm_cfg,
  165. rfp_inplanes=rfp_inplanes,
  166. **kwargs))
  167. inplanes = planes * block.expansion
  168. for _ in range(1, num_blocks):
  169. layers.append(
  170. block(
  171. inplanes=inplanes,
  172. planes=planes,
  173. stride=1,
  174. conv_cfg=conv_cfg,
  175. norm_cfg=norm_cfg,
  176. **kwargs))
  177. super(ResLayer, self).__init__(*layers)
  178. @MODELS.register_module()
  179. class DetectoRS_ResNet(ResNet):
  180. """ResNet backbone for DetectoRS.
  181. Args:
  182. sac (dict, optional): Dictionary to construct SAC (Switchable Atrous
  183. Convolution). Default: None.
  184. stage_with_sac (list): Which stage to use sac. Default: (False, False,
  185. False, False).
  186. rfp_inplanes (int, optional): The number of channels from RFP.
  187. Default: None. If specified, an additional conv layer will be
  188. added for ``rfp_feat``. Otherwise, the structure is the same as
  189. base class.
  190. output_img (bool): If ``True``, the input image will be inserted into
  191. the starting position of output. Default: False.
  192. """
  193. arch_settings = {
  194. 50: (Bottleneck, (3, 4, 6, 3)),
  195. 101: (Bottleneck, (3, 4, 23, 3)),
  196. 152: (Bottleneck, (3, 8, 36, 3))
  197. }
  198. def __init__(self,
  199. sac=None,
  200. stage_with_sac=(False, False, False, False),
  201. rfp_inplanes=None,
  202. output_img=False,
  203. pretrained=None,
  204. init_cfg=None,
  205. **kwargs):
  206. assert not (init_cfg and pretrained), \
  207. 'init_cfg and pretrained cannot be specified at the same time'
  208. self.pretrained = pretrained
  209. if init_cfg is not None:
  210. assert isinstance(init_cfg, dict), \
  211. f'init_cfg must be a dict, but got {type(init_cfg)}'
  212. if 'type' in init_cfg:
  213. assert init_cfg.get('type') == 'Pretrained', \
  214. 'Only can initialize module by loading a pretrained model'
  215. else:
  216. raise KeyError('`init_cfg` must contain the key "type"')
  217. self.pretrained = init_cfg.get('checkpoint')
  218. self.sac = sac
  219. self.stage_with_sac = stage_with_sac
  220. self.rfp_inplanes = rfp_inplanes
  221. self.output_img = output_img
  222. super(DetectoRS_ResNet, self).__init__(**kwargs)
  223. self.inplanes = self.stem_channels
  224. self.res_layers = []
  225. for i, num_blocks in enumerate(self.stage_blocks):
  226. stride = self.strides[i]
  227. dilation = self.dilations[i]
  228. dcn = self.dcn if self.stage_with_dcn[i] else None
  229. sac = self.sac if self.stage_with_sac[i] else None
  230. if self.plugins is not None:
  231. stage_plugins = self.make_stage_plugins(self.plugins, i)
  232. else:
  233. stage_plugins = None
  234. planes = self.base_channels * 2**i
  235. res_layer = self.make_res_layer(
  236. block=self.block,
  237. inplanes=self.inplanes,
  238. planes=planes,
  239. num_blocks=num_blocks,
  240. stride=stride,
  241. dilation=dilation,
  242. style=self.style,
  243. avg_down=self.avg_down,
  244. with_cp=self.with_cp,
  245. conv_cfg=self.conv_cfg,
  246. norm_cfg=self.norm_cfg,
  247. dcn=dcn,
  248. sac=sac,
  249. rfp_inplanes=rfp_inplanes if i > 0 else None,
  250. plugins=stage_plugins)
  251. self.inplanes = planes * self.block.expansion
  252. layer_name = f'layer{i + 1}'
  253. self.add_module(layer_name, res_layer)
  254. self.res_layers.append(layer_name)
  255. self._freeze_stages()
  256. # In order to be properly initialized by RFP
  257. def init_weights(self):
  258. # Calling this method will cause parameter initialization exception
  259. # super(DetectoRS_ResNet, self).init_weights()
  260. if isinstance(self.pretrained, str):
  261. logger = MMLogger.get_current_instance()
  262. load_checkpoint(self, self.pretrained, strict=False, logger=logger)
  263. elif self.pretrained is None:
  264. for m in self.modules():
  265. if isinstance(m, nn.Conv2d):
  266. kaiming_init(m)
  267. elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
  268. constant_init(m, 1)
  269. if self.dcn is not None:
  270. for m in self.modules():
  271. if isinstance(m, Bottleneck) and hasattr(
  272. m.conv2, 'conv_offset'):
  273. constant_init(m.conv2.conv_offset, 0)
  274. if self.zero_init_residual:
  275. for m in self.modules():
  276. if isinstance(m, Bottleneck):
  277. constant_init(m.norm3, 0)
  278. elif isinstance(m, BasicBlock):
  279. constant_init(m.norm2, 0)
  280. else:
  281. raise TypeError('pretrained must be a str or None')
  282. def make_res_layer(self, **kwargs):
  283. """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS."""
  284. return ResLayer(**kwargs)
  285. def forward(self, x):
  286. """Forward function."""
  287. outs = list(super(DetectoRS_ResNet, self).forward(x))
  288. if self.output_img:
  289. outs.insert(0, x)
  290. return tuple(outs)
  291. def rfp_forward(self, x, rfp_feats):
  292. """Forward function for RFP."""
  293. if self.deep_stem:
  294. x = self.stem(x)
  295. else:
  296. x = self.conv1(x)
  297. x = self.norm1(x)
  298. x = self.relu(x)
  299. x = self.maxpool(x)
  300. outs = []
  301. for i, layer_name in enumerate(self.res_layers):
  302. res_layer = getattr(self, layer_name)
  303. rfp_feat = rfp_feats[i] if i > 0 else None
  304. for layer in res_layer:
  305. x = layer.rfp_forward(x, rfp_feat)
  306. if i in self.out_indices:
  307. outs.append(x)
  308. return tuple(outs)