trident_resnet.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from torch.nn.modules.utils import _pair

from mmdet.models.backbones.resnet import Bottleneck, ResNet
from mmdet.registry import MODELS


class TridentConv(BaseModule):
    """Trident Convolution Module.

    Args:
        in_channels (int): Number of channels in input.
        out_channels (int): Number of channels in output.
        kernel_size (int): Size of convolution kernel.
        stride (int, optional): Convolution stride. Default: 1.
        trident_dilations (tuple[int, int, int], optional): Dilations of
            different trident branch. Default: (1, 2, 3).
        test_branch_idx (int, optional): In inference, all 3 branches will
            be used if `test_branch_idx==-1`, otherwise only branch with
            index `test_branch_idx` will be used. Default: 1.
        bias (bool, optional): Whether to use bias in convolution or not.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 trident_dilations=(1, 2, 3),
                 test_branch_idx=1,
                 bias=False,
                 init_cfg=None):
        super(TridentConv, self).__init__(init_cfg)
        self.num_branch = len(trident_dilations)
        self.with_bias = bias
        self.test_branch_idx = test_branch_idx
        self.stride = _pair(stride)
        self.kernel_size = _pair(kernel_size)
        self.paddings = _pair(trident_dilations)
        self.dilations = trident_dilations
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

    def extra_repr(self):
        tmpstr = f'in_channels={self.in_channels}'
        tmpstr += f', out_channels={self.out_channels}'
        tmpstr += f', kernel_size={self.kernel_size}'
        tmpstr += f', num_branch={self.num_branch}'
        tmpstr += f', test_branch_idx={self.test_branch_idx}'
        tmpstr += f', stride={self.stride}'
        tmpstr += f', paddings={self.paddings}'
        tmpstr += f', dilations={self.dilations}'
        tmpstr += f', bias={self.bias}'
        return tmpstr

    def forward(self, inputs):
        if self.training or self.test_branch_idx == -1:
            outputs = [
                F.conv2d(input, self.weight, self.bias, self.stride, padding,
                         dilation) for input, dilation, padding in zip(
                             inputs, self.dilations, self.paddings)
            ]
        else:
            assert len(inputs) == 1
            outputs = [
                F.conv2d(inputs[0], self.weight, self.bias, self.stride,
                         self.paddings[self.test_branch_idx],
                         self.dilations[self.test_branch_idx])
            ]

        return outputs
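

# Illustrative sketch (not part of the upstream module): TridentConv applies
# one shared weight with a different dilation (and matching padding) per
# branch, so it consumes and returns a *list* of tensors. The helper name
# below is hypothetical and exists only to show the expected list lengths
# and shapes.
def _demo_trident_conv():
    conv = TridentConv(16, 16, kernel_size=3, trident_dilations=(1, 2, 3),
                       test_branch_idx=1)
    # Training (or test_branch_idx == -1): one input tensor per branch and
    # one output tensor per branch; spatial size is preserved because
    # padding == dilation for the 3x3 kernel.
    inputs = [torch.randn(2, 16, 32, 32) for _ in range(3)]
    outputs = conv(inputs)
    assert len(outputs) == 3 and outputs[0].shape == (2, 16, 32, 32)
    # Inference with test_branch_idx == 1: a single-element list is expected
    # and only that branch's dilation is applied.
    conv.eval()
    outputs = conv([torch.randn(2, 16, 32, 32)])
    assert len(outputs) == 1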


# Since TridentNet is defined over ResNet50 and ResNet101, here we
# only support TridentBottleneckBlock.
class TridentBottleneck(Bottleneck):
    """Bottleneck block for TridentResNet.

    Args:
        trident_dilations (tuple[int, int, int]): Dilations of different
            trident branch.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        concat_output (bool): Whether to concat the output list to a Tensor.
            `True` only in the last Block.
    """

    def __init__(self, trident_dilations, test_branch_idx, concat_output,
                 **kwargs):
        super(TridentBottleneck, self).__init__(**kwargs)
        self.trident_dilations = trident_dilations
        self.num_branch = len(trident_dilations)
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx
        self.conv2 = TridentConv(
            self.planes,
            self.planes,
            kernel_size=3,
            stride=self.conv2_stride,
            bias=False,
            trident_dilations=self.trident_dilations,
            test_branch_idx=test_branch_idx,
            init_cfg=dict(
                type='Kaiming',
                distribution='uniform',
                mode='fan_in',
                override=dict(name='conv2')))

    def forward(self, x):

        def _inner_forward(x):
            num_branch = (
                self.num_branch
                if self.training or self.test_branch_idx == -1 else 1)
            identity = x
            if not isinstance(x, list):
                x = (x, ) * num_branch
                identity = x
                if self.downsample is not None:
                    identity = [self.downsample(b) for b in x]

            out = [self.conv1(b) for b in x]
            out = [self.norm1(b) for b in out]
            out = [self.relu(b) for b in out]

            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(
                        out[k], self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = [self.norm2(b) for b in out]
            out = [self.relu(b) for b in out]

            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(
                        out[k], self.after_conv2_plugin_names)

            out = [self.conv3(b) for b in out]
            out = [self.norm3(b) for b in out]

            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(
                        out[k], self.after_conv3_plugin_names)

            out = [
                out_b + identity_b for out_b, identity_b in zip(out, identity)
            ]
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = [self.relu(b) for b in out]
        if self.concat_output:
            out = torch.cat(out, dim=0)
        return out
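

# Illustrative sketch (not part of the upstream module): in training mode a
# single input tensor is broadcast to one copy per branch, every branch shares
# conv1/conv3 and the TridentConv conv2, and the last block of a stage
# (concat_output=True) stacks the branch outputs along the batch dimension.
# The helper name and the channel sizes below are hypothetical.
def _demo_trident_bottleneck():
    block = TridentBottleneck(
        trident_dilations=(1, 2, 3),
        test_branch_idx=1,
        concat_output=True,
        inplanes=64,
        planes=16)
    out = block(torch.randn(2, 64, 32, 32))
    # 3 branches concatenated along dim 0: (3 * 2, planes * expansion, H, W).
    assert out.shape == (6, 64, 32, 32)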


def make_trident_res_layer(block,
                           inplanes,
                           planes,
                           num_blocks,
                           stride=1,
                           trident_dilations=(1, 2, 3),
                           style='pytorch',
                           with_cp=False,
                           conv_cfg=None,
                           norm_cfg=dict(type='BN'),
                           dcn=None,
                           plugins=None,
                           test_branch_idx=-1):
    """Build Trident Res Layers."""
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = []
        conv_stride = stride
        downsample.extend([
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=conv_stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1]
        ])
        downsample = nn.Sequential(*downsample)

    layers = []
    for i in range(num_blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride if i == 0 else 1,
                trident_dilations=trident_dilations,
                downsample=downsample if i == 0 else None,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=plugins,
                test_branch_idx=test_branch_idx,
                concat_output=True if i == num_blocks - 1 else False))
        inplanes = planes * block.expansion
    return nn.Sequential(*layers)


@MODELS.register_module()
class TridentResNet(ResNet):
    """The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
    ResNet, while in stage 3, the Trident Bottleneck block replaces the normal
    Bottleneck block to yield the trident output. The different branches share
    the convolution weights but use different dilations to achieve multi-scale
    output.

                                 / stage3(b0) \
    x - stem - stage1 - stage2 - stage3(b1)   - output
                                 \ stage3(b2) /

    Args:
        depth (int): Depth of resnet, from {50, 101, 152}.
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        trident_dilations (tuple[int]): Dilations of different trident branch.
            len(trident_dilations) should be equal to num_branch.
    """  # noqa

    def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
                 **kwargs):
        assert num_branch == len(trident_dilations)
        assert depth in (50, 101, 152)
        super(TridentResNet, self).__init__(depth, **kwargs)
        assert self.num_stages == 3
        self.test_branch_idx = test_branch_idx
        self.num_branch = num_branch

        last_stage_idx = self.num_stages - 1
        stride = self.strides[last_stage_idx]
        dilation = trident_dilations
        dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
        if self.plugins is not None:
            stage_plugins = self.make_stage_plugins(self.plugins,
                                                    last_stage_idx)
        else:
            stage_plugins = None
        planes = self.base_channels * 2**last_stage_idx
        res_layer = make_trident_res_layer(
            TridentBottleneck,
            inplanes=(self.block.expansion * self.base_channels *
                      2**(last_stage_idx - 1)),
            planes=planes,
            num_blocks=self.stage_blocks[last_stage_idx],
            stride=stride,
            trident_dilations=dilation,
            style=self.style,
            with_cp=self.with_cp,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=dcn,
            plugins=stage_plugins,
            test_branch_idx=self.test_branch_idx)

        layer_name = f'layer{last_stage_idx + 1}'

        self.__setattr__(layer_name, res_layer)
        self.res_layers.pop(last_stage_idx)
        self.res_layers.insert(last_stage_idx, layer_name)

        self._freeze_stages()
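

# Illustrative sketch (not part of the upstream module): one way to build the
# backbone directly. TridentResNet requires num_stages == 3, so the stage-wise
# arguments below have length 3; the concrete values (strides, dilations,
# out_indices) are assumptions for demonstration, not a reference config.
if __name__ == '__main__':
    model = TridentResNet(
        depth=50,
        num_branch=3,
        test_branch_idx=1,
        trident_dilations=(1, 2, 3),
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ))
    model.train()
    feats = model(torch.randn(1, 3, 224, 224))
    # The trident stage concatenates its three branches along the batch axis,
    # so the returned feature map has 3x the input batch size.
    print([tuple(f.shape) for f in feats])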