cspnext.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence, Tuple

import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from ..layers import CSPLayer
from .csp_darknet import SPPBottleneck


@MODELS.register_module()
class CSPNeXt(BaseModule):
  15. """CSPNeXt backbone used in RTMDet.
  16. Args:
  17. arch (str): Architecture of CSPNeXt, from {P5, P6}.
  18. Defaults to P5.
  19. expand_ratio (float): Ratio to adjust the number of channels of the
  20. hidden layer. Defaults to 0.5.
  21. deepen_factor (float): Depth multiplier, multiply number of
  22. blocks in CSP layer by this amount. Defaults to 1.0.
  23. widen_factor (float): Width multiplier, multiply number of
  24. channels in each layer by this amount. Defaults to 1.0.
  25. out_indices (Sequence[int]): Output from which stages.
  26. Defaults to (2, 3, 4).
  27. frozen_stages (int): Stages to be frozen (stop grad and set eval
  28. mode). -1 means not freezing any parameters. Defaults to -1.
  29. use_depthwise (bool): Whether to use depthwise separable convolution.
  30. Defaults to False.
  31. arch_ovewrite (list): Overwrite default arch settings.
  32. Defaults to None.
  33. spp_kernel_sizes: (tuple[int]): Sequential of kernel sizes of SPP
  34. layers. Defaults to (5, 9, 13).
  35. channel_attention (bool): Whether to add channel attention in each
  36. stage. Defaults to True.
  37. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
  38. convolution layer. Defaults to None.
  39. norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
  40. config norm layer. Defaults to dict(type='BN', requires_grad=True).
  41. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
  42. Defaults to dict(type='SiLU').
  43. norm_eval (bool): Whether to set norm layers to eval mode, namely,
  44. freeze running stats (mean and var). Note: Effect on Batch Norm
  45. and its variants only.
  46. init_cfg (:obj:`ConfigDict` or dict or list[dict] or
  47. list[:obj:`ConfigDict`]): Initialization config dict.
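
    Example:
        >>> # A minimal smoke test; the printed shapes assume the default
        >>> # P5 architecture with deepen_factor=1.0 and widen_factor=1.0.
        >>> import torch
        >>> from mmdet.models.backbones import CSPNeXt
        >>> model = CSPNeXt(arch='P5', out_indices=(2, 3, 4))
        >>> _ = model.eval()
        >>> x = torch.rand(1, 3, 640, 640)
        >>> outs = model(x)
        >>> for out in outs:
        ...     print(tuple(out.shape))
        (1, 256, 80, 80)
        (1, 512, 40, 40)
        (1, 1024, 20, 20)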
  48. """

    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
               [256, 512, 6, True, False], [512, 1024, 3, False, True]],
        'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
               [256, 512, 6, True, False], [512, 768, 3, True, False],
               [768, 1024, 3, False, True]]
    }
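
    # The stem downsamples once and each stage downsamples once more, so the
    # default out_indices (2, 3, 4) pick features at strides 8, 16 and 32.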

    def __init__(
        self,
        arch: str = 'P5',
        deepen_factor: float = 1.0,
        widen_factor: float = 1.0,
        out_indices: Sequence[int] = (2, 3, 4),
        frozen_stages: int = -1,
        use_depthwise: bool = False,
        expand_ratio: float = 0.5,
        arch_ovewrite: list = None,
        spp_kernel_sizes: Sequence[int] = (5, 9, 13),
        channel_attention: bool = True,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='SiLU'),
        norm_eval: bool = False,
        init_cfg: OptMultiConfig = dict(
            type='Kaiming',
            layer='Conv2d',
            a=math.sqrt(5),
            distribution='uniform',
            mode='fan_in',
            nonlinearity='leaky_relu')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        arch_setting = self.arch_settings[arch]
        if arch_ovewrite:
            arch_setting = arch_ovewrite
        assert set(out_indices).issubset(
            i for i in range(len(arch_setting) + 1))
        if frozen_stages not in range(-1, len(arch_setting) + 1):
            raise ValueError('frozen_stages must be in range(-1, '
                             'len(arch_setting) + 1). But received '
                             f'{frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.use_depthwise = use_depthwise
        self.norm_eval = norm_eval
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
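        # Stem: three 3x3 convs instead of a single large-kernel conv; only
        # the first conv downsamples (stride 2), at half the stage-1 width.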
        self.stem = nn.Sequential(
            ConvModule(
                3,
                int(arch_setting[0][0] * widen_factor // 2),
                3,
                padding=1,
                stride=2,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                int(arch_setting[0][0] * widen_factor // 2),
                int(arch_setting[0][0] * widen_factor // 2),
                3,
                padding=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                int(arch_setting[0][0] * widen_factor // 2),
                int(arch_setting[0][0] * widen_factor),
                3,
                padding=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))
        self.layers = ['stem']
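        # Each stage: a stride-2 conv for downsampling, an optional SPP
        # bottleneck (last stage only by default), then a CSP layer.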
        for i, (in_channels, out_channels, num_blocks, add_identity,
                use_spp) in enumerate(arch_setting):
            in_channels = int(in_channels * widen_factor)
            out_channels = int(out_channels * widen_factor)
            num_blocks = max(round(num_blocks * deepen_factor), 1)
            stage = []
            conv_layer = conv(
                in_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            stage.append(conv_layer)
            if use_spp:
                spp = SPPBottleneck(
                    out_channels,
                    out_channels,
                    kernel_sizes=spp_kernel_sizes,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                stage.append(spp)
            csp_layer = CSPLayer(
                out_channels,
                out_channels,
                num_blocks=num_blocks,
                add_identity=add_identity,
                use_depthwise=use_depthwise,
                use_cspnext_block=True,
                expand_ratio=expand_ratio,
                channel_attention=channel_attention,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            stage.append(csp_layer)
            self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{i + 1}')

    def _freeze_stages(self) -> None:
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True) -> None:
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            # Note: index 0 is the stem, so out_indices count it as a stage.
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
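

# Usage sketch (illustrative values, not taken from a specific RTMDet
# config): the backbone is normally built from a config dict through the
# registry, scaling depth and width via deepen_factor and widen_factor, e.g.
#     from mmdet.registry import MODELS
#     backbone = MODELS.build(
#         dict(
#             type='CSPNeXt',
#             arch='P5',
#             deepen_factor=0.33,
#             widen_factor=0.5,
#             out_indices=(2, 3, 4)))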