resnest.py

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from ..layers import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d


class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x
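
# Illustrative use of ``RSoftmax`` (not part of the original module; the
# tensor shapes below are assumptions chosen only to show the reshaping):
# with radix=2 and groups=1, logits of shape (B, 2 * C, 1, 1) are viewed as
# (B, groups, radix, C), softmax is taken over the radix dimension, and the
# result is flattened back to (B, 2 * C), so the two radix weights of each
# channel sum to one.
#
#   rsoftmax = RSoftmax(radix=2, groups=1)
#   logits = torch.randn(4, 2 * 16, 1, 1)
#   weights = rsoftmax(logits)               # shape: (4, 32)
#   weights.view(4, 2, 16).sum(dim=1)        # all ones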


class SplitAttentionConv2d(BaseModule):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Number of channels in the input feature map.
        channels (int): Number of intermediate channels.
        kernel_size (int | tuple[int]): Size of the convolution kernel.
        stride (int | tuple[int]): Stride of the convolution.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input.
        dilation (int | tuple[int]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as ``nn.Conv2d``.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        dcn (dict): Config dict for DCN. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 init_cfg=None):
        super(SplitAttentionConv2d, self).__init__(init_cfg)
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        # To be consistent with original implementation, starting from 0
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.norm1(gap)
        gap = self.relu(gap)
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
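
# Usage sketch for ``SplitAttentionConv2d`` (illustrative, not from the
# original file; the channel counts and input size are assumptions):
#
#   conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1, radix=2)
#   x = torch.randn(2, 64, 56, 56)
#   out = conv(x)   # -> (2, 64, 56, 56): the two radix branches produced by
#                   #    the grouped conv are fused with the learned
#                   #    split-attention weights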


class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2.
        base_width (int): Base of width in terms of base channels. Default: 4.
        base_channels (int): Base of channels for calculating width.
            Default: 64.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            if self.avg_down_stride:
                out = self.avd_layer(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
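
# Construction sketch for the ResNeSt ``Bottleneck`` (illustrative; the
# channel sizes are assumptions and rely only on the arguments shown above):
#
#   block = Bottleneck(inplanes=256, planes=64, radix=2, avg_down_stride=True)
#   y = block(torch.randn(2, 256, 56, 56))   # -> (2, 256, 56, 56); the
#                                            #    identity shortcut is added
#                                            #    since inplanes == planes * 4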


@MODELS.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1.
        base_width (int): Base width of Bottleneck. Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
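
# Backbone construction sketch (illustrative; the depth, out_indices and
# input size are assumptions mirroring common ResNeSt-50 settings rather
# than values taken from this file):
#
#   backbone = ResNeSt(
#       depth=50,
#       radix=2,
#       reduction_factor=4,
#       avg_down_stride=True,
#       out_indices=(0, 1, 2, 3))
#   feats = backbone(torch.randn(1, 3, 224, 224))
#   # feats is a tuple of four feature maps with 256, 512, 1024 and 2048
#   # channels at strides 4, 8, 16 and 32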