res_layer.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. from mmcv.cnn import build_conv_layer, build_norm_layer
  4. from mmengine.model import BaseModule, Sequential
  5. from torch import Tensor
  6. from torch import nn as nn
  7. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  8. class ResLayer(Sequential):
  9. """ResLayer to build ResNet style backbone.
  10. Args:
  11. block (nn.Module): block used to build ResLayer.
  12. inplanes (int): inplanes of block.
  13. planes (int): planes of block.
  14. num_blocks (int): number of blocks.
  15. stride (int): stride of the first block. Defaults to 1
  16. avg_down (bool): Use AvgPool instead of stride conv when
  17. downsampling in the bottleneck. Defaults to False
  18. conv_cfg (dict): dictionary to construct and config conv layer.
  19. Defaults to None
  20. norm_cfg (dict): dictionary to construct and config norm layer.
  21. Defaults to dict(type='BN')
  22. downsample_first (bool): Downsample at the first block or last block.
  23. False for Hourglass, True for ResNet. Defaults to True
  24. """
  25. def __init__(self,
  26. block: BaseModule,
  27. inplanes: int,
  28. planes: int,
  29. num_blocks: int,
  30. stride: int = 1,
  31. avg_down: bool = False,
  32. conv_cfg: OptConfigType = None,
  33. norm_cfg: ConfigType = dict(type='BN'),
  34. downsample_first: bool = True,
  35. **kwargs) -> None:
  36. self.block = block
  37. downsample = None
  38. if stride != 1 or inplanes != planes * block.expansion:
  39. downsample = []
  40. conv_stride = stride
  41. if avg_down:
  42. conv_stride = 1
  43. downsample.append(
  44. nn.AvgPool2d(
  45. kernel_size=stride,
  46. stride=stride,
  47. ceil_mode=True,
  48. count_include_pad=False))
  49. downsample.extend([
  50. build_conv_layer(
  51. conv_cfg,
  52. inplanes,
  53. planes * block.expansion,
  54. kernel_size=1,
  55. stride=conv_stride,
  56. bias=False),
  57. build_norm_layer(norm_cfg, planes * block.expansion)[1]
  58. ])
  59. downsample = nn.Sequential(*downsample)
  60. layers = []
  61. if downsample_first:
  62. layers.append(
  63. block(
  64. inplanes=inplanes,
  65. planes=planes,
  66. stride=stride,
  67. downsample=downsample,
  68. conv_cfg=conv_cfg,
  69. norm_cfg=norm_cfg,
  70. **kwargs))
  71. inplanes = planes * block.expansion
  72. for _ in range(1, num_blocks):
  73. layers.append(
  74. block(
  75. inplanes=inplanes,
  76. planes=planes,
  77. stride=1,
  78. conv_cfg=conv_cfg,
  79. norm_cfg=norm_cfg,
  80. **kwargs))
  81. else: # downsample_first=False is for HourglassModule
  82. for _ in range(num_blocks - 1):
  83. layers.append(
  84. block(
  85. inplanes=inplanes,
  86. planes=inplanes,
  87. stride=1,
  88. conv_cfg=conv_cfg,
  89. norm_cfg=norm_cfg,
  90. **kwargs))
  91. layers.append(
  92. block(
  93. inplanes=inplanes,
  94. planes=planes,
  95. stride=stride,
  96. downsample=downsample,
  97. conv_cfg=conv_cfg,
  98. norm_cfg=norm_cfg,
  99. **kwargs))
  100. super().__init__(*layers)
  101. class SimplifiedBasicBlock(BaseModule):
  102. """Simplified version of original basic residual block. This is used in
  103. `SCNet <https://arxiv.org/abs/2012.10150>`_.
  104. - Norm layer is now optional
  105. - Last ReLU in forward function is removed
  106. """
  107. expansion = 1
  108. def __init__(self,
  109. inplanes: int,
  110. planes: int,
  111. stride: int = 1,
  112. dilation: int = 1,
  113. downsample: Optional[Sequential] = None,
  114. style: ConfigType = 'pytorch',
  115. with_cp: bool = False,
  116. conv_cfg: OptConfigType = None,
  117. norm_cfg: ConfigType = dict(type='BN'),
  118. dcn: OptConfigType = None,
  119. plugins: OptConfigType = None,
  120. init_cfg: OptMultiConfig = None) -> None:
  121. super().__init__(init_cfg=init_cfg)
  122. assert dcn is None, 'Not implemented yet.'
  123. assert plugins is None, 'Not implemented yet.'
  124. assert not with_cp, 'Not implemented yet.'
  125. self.with_norm = norm_cfg is not None
  126. with_bias = True if norm_cfg is None else False
  127. self.conv1 = build_conv_layer(
  128. conv_cfg,
  129. inplanes,
  130. planes,
  131. 3,
  132. stride=stride,
  133. padding=dilation,
  134. dilation=dilation,
  135. bias=with_bias)
  136. if self.with_norm:
  137. self.norm1_name, norm1 = build_norm_layer(
  138. norm_cfg, planes, postfix=1)
  139. self.add_module(self.norm1_name, norm1)
  140. self.conv2 = build_conv_layer(
  141. conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
  142. if self.with_norm:
  143. self.norm2_name, norm2 = build_norm_layer(
  144. norm_cfg, planes, postfix=2)
  145. self.add_module(self.norm2_name, norm2)
  146. self.relu = nn.ReLU(inplace=True)
  147. self.downsample = downsample
  148. self.stride = stride
  149. self.dilation = dilation
  150. self.with_cp = with_cp
  151. @property
  152. def norm1(self) -> Optional[BaseModule]:
  153. """nn.Module: normalization layer after the first convolution layer"""
  154. return getattr(self, self.norm1_name) if self.with_norm else None
  155. @property
  156. def norm2(self) -> Optional[BaseModule]:
  157. """nn.Module: normalization layer after the second convolution layer"""
  158. return getattr(self, self.norm2_name) if self.with_norm else None
  159. def forward(self, x: Tensor) -> Tensor:
  160. """Forward function for SimplifiedBasicBlock."""
  161. identity = x
  162. out = self.conv1(x)
  163. if self.with_norm:
  164. out = self.norm1(out)
  165. out = self.relu(out)
  166. out = self.conv2(out)
  167. if self.with_norm:
  168. out = self.norm2(out)
  169. if self.downsample is not None:
  170. identity = self.downsample(x)
  171. out += identity
  172. return out