ssh.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmcv.cnn import ConvModule
  6. from mmengine.model import BaseModule
  7. from mmdet.registry import MODELS
  8. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  9. class SSHContextModule(BaseModule):
  10. """This is an implementation of `SSH context module` described in `SSH:
  11. Single Stage Headless Face Detector.
  12. <https://arxiv.org/pdf/1708.03979.pdf>`_.
  13. Args:
  14. in_channels (int): Number of input channels used at each scale.
  15. out_channels (int): Number of output channels used at each scale.
  16. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
  17. convolution layer. Defaults to None.
  18. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
  19. layer. Defaults to dict(type='BN').
  20. init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
  21. list[dict], optional): Initialization config dict.
  22. Defaults to None.
  23. """
  24. def __init__(self,
  25. in_channels: int,
  26. out_channels: int,
  27. conv_cfg: OptConfigType = None,
  28. norm_cfg: ConfigType = dict(type='BN'),
  29. init_cfg: OptMultiConfig = None):
  30. super().__init__(init_cfg=init_cfg)
  31. assert out_channels % 4 == 0
  32. self.in_channels = in_channels
  33. self.out_channels = out_channels
  34. self.conv5x5_1 = ConvModule(
  35. self.in_channels,
  36. self.out_channels // 4,
  37. 3,
  38. stride=1,
  39. padding=1,
  40. conv_cfg=conv_cfg,
  41. norm_cfg=norm_cfg,
  42. )
  43. self.conv5x5_2 = ConvModule(
  44. self.out_channels // 4,
  45. self.out_channels // 4,
  46. 3,
  47. stride=1,
  48. padding=1,
  49. conv_cfg=conv_cfg,
  50. norm_cfg=norm_cfg,
  51. act_cfg=None)
  52. self.conv7x7_2 = ConvModule(
  53. self.out_channels // 4,
  54. self.out_channels // 4,
  55. 3,
  56. stride=1,
  57. padding=1,
  58. conv_cfg=conv_cfg,
  59. norm_cfg=norm_cfg,
  60. )
  61. self.conv7x7_3 = ConvModule(
  62. self.out_channels // 4,
  63. self.out_channels // 4,
  64. 3,
  65. stride=1,
  66. padding=1,
  67. conv_cfg=conv_cfg,
  68. norm_cfg=norm_cfg,
  69. act_cfg=None,
  70. )
  71. def forward(self, x: torch.Tensor) -> tuple:
  72. conv5x5_1 = self.conv5x5_1(x)
  73. conv5x5 = self.conv5x5_2(conv5x5_1)
  74. conv7x7_2 = self.conv7x7_2(conv5x5_1)
  75. conv7x7 = self.conv7x7_3(conv7x7_2)
  76. return (conv5x5, conv7x7)
  77. class SSHDetModule(BaseModule):
  78. """This is an implementation of `SSH detection module` described in `SSH:
  79. Single Stage Headless Face Detector.
  80. <https://arxiv.org/pdf/1708.03979.pdf>`_.
  81. Args:
  82. in_channels (int): Number of input channels used at each scale.
  83. out_channels (int): Number of output channels used at each scale.
  84. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
  85. convolution layer. Defaults to None.
  86. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
  87. layer. Defaults to dict(type='BN').
  88. init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
  89. list[dict], optional): Initialization config dict.
  90. Defaults to None.
  91. """
  92. def __init__(self,
  93. in_channels: int,
  94. out_channels: int,
  95. conv_cfg: OptConfigType = None,
  96. norm_cfg: ConfigType = dict(type='BN'),
  97. init_cfg: OptMultiConfig = None):
  98. super().__init__(init_cfg=init_cfg)
  99. assert out_channels % 4 == 0
  100. self.in_channels = in_channels
  101. self.out_channels = out_channels
  102. self.conv3x3 = ConvModule(
  103. self.in_channels,
  104. self.out_channels // 2,
  105. 3,
  106. stride=1,
  107. padding=1,
  108. conv_cfg=conv_cfg,
  109. norm_cfg=norm_cfg,
  110. act_cfg=None)
  111. self.context_module = SSHContextModule(
  112. in_channels=self.in_channels,
  113. out_channels=self.out_channels,
  114. conv_cfg=conv_cfg,
  115. norm_cfg=norm_cfg)
  116. def forward(self, x: torch.Tensor) -> torch.Tensor:
  117. conv3x3 = self.conv3x3(x)
  118. conv5x5, conv7x7 = self.context_module(x)
  119. out = torch.cat([conv3x3, conv5x5, conv7x7], dim=1)
  120. out = F.relu(out)
  121. return out
  122. @MODELS.register_module()
  123. class SSH(BaseModule):
  124. """`SSH Neck` used in `SSH: Single Stage Headless Face Detector.
  125. <https://arxiv.org/pdf/1708.03979.pdf>`_.
  126. Args:
  127. num_scales (int): The number of scales / stages.
  128. in_channels (list[int]): The number of input channels per scale.
  129. out_channels (list[int]): The number of output channels per scale.
  130. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
  131. convolution layer. Defaults to None.
  132. norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
  133. layer. Defaults to dict(type='BN').
  134. init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
  135. list[dict], optional): Initialization config dict.
  136. Example:
  137. >>> import torch
  138. >>> in_channels = [8, 16, 32, 64]
  139. >>> out_channels = [16, 32, 64, 128]
  140. >>> scales = [340, 170, 84, 43]
  141. >>> inputs = [torch.rand(1, c, s, s)
  142. ... for c, s in zip(in_channels, scales)]
  143. >>> self = SSH(num_scales=4, in_channels=in_channels,
  144. ... out_channels=out_channels)
  145. >>> outputs = self.forward(inputs)
  146. >>> for i in range(len(outputs)):
  147. ... print(f'outputs[{i}].shape = {outputs[i].shape}')
  148. outputs[0].shape = torch.Size([1, 16, 340, 340])
  149. outputs[1].shape = torch.Size([1, 32, 170, 170])
  150. outputs[2].shape = torch.Size([1, 64, 84, 84])
  151. outputs[3].shape = torch.Size([1, 128, 43, 43])
  152. """
  153. def __init__(self,
  154. num_scales: int,
  155. in_channels: List[int],
  156. out_channels: List[int],
  157. conv_cfg: OptConfigType = None,
  158. norm_cfg: ConfigType = dict(type='BN'),
  159. init_cfg: OptMultiConfig = dict(
  160. type='Xavier', layer='Conv2d', distribution='uniform')):
  161. super().__init__(init_cfg=init_cfg)
  162. assert (num_scales == len(in_channels) == len(out_channels))
  163. self.num_scales = num_scales
  164. self.in_channels = in_channels
  165. self.out_channels = out_channels
  166. for idx in range(self.num_scales):
  167. in_c, out_c = self.in_channels[idx], self.out_channels[idx]
  168. self.add_module(
  169. f'ssh_module{idx}',
  170. SSHDetModule(
  171. in_channels=in_c,
  172. out_channels=out_c,
  173. conv_cfg=conv_cfg,
  174. norm_cfg=norm_cfg))
  175. def forward(self, inputs: Tuple[torch.Tensor]) -> tuple:
  176. assert len(inputs) == self.num_scales
  177. outs = []
  178. for idx, x in enumerate(inputs):
  179. ssh_module = getattr(self, f'ssh_module{idx}')
  180. out = ssh_module(x)
  181. outs.append(out)
  182. return tuple(outs)