# Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple
- import torch
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmengine.model import BaseModule
- from mmdet.registry import MODELS
- from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
class SSHContextModule(BaseModule):
    """Context module of `SSH: Single Stage Headless Face Detector.

    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    The module is built only from 3x3 convolutions (stride 1, padding 1), so
    the spatial size of the input is kept. A first 3x3 conv is shared by two
    branches:

    - the "5x5" branch stacks one more 3x3 conv on top of it (two 3x3 convs
      in a row cover a 5x5 receptive field);
    - the "7x7" branch stacks two more 3x3 convs on top of it (three 3x3
      convs in a row cover a 7x7 receptive field).

    Every conv produces ``out_channels // 4`` channels. The last conv of
    each branch is created with ``act_cfg=None``.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Width of every conv in this module.
        quarter = self.out_channels // 4
        # Settings shared by all four 3x3 convolutions.
        shared_kwargs = dict(
            stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # Stem shared by both branches.
        self.conv5x5_1 = ConvModule(
            self.in_channels, quarter, 3, **shared_kwargs)
        # Tail of the 5x5 branch; no activation config is attached.
        self.conv5x5_2 = ConvModule(
            quarter, quarter, 3, act_cfg=None, **shared_kwargs)
        # Middle and tail of the 7x7 branch; only the tail drops act_cfg.
        self.conv7x7_2 = ConvModule(quarter, quarter, 3, **shared_kwargs)
        self.conv7x7_3 = ConvModule(
            quarter, quarter, 3, act_cfg=None, **shared_kwargs)

    def forward(self, x: torch.Tensor) -> tuple:
        """Run both context branches on ``x``.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            tuple: ``(conv5x5, conv7x7)``, the outputs of the 5x5 and the
            7x7 branch, in that order.
        """
        stem = self.conv5x5_1(x)
        out_5x5 = self.conv5x5_2(stem)
        out_7x7 = self.conv7x7_3(self.conv7x7_2(stem))
        return out_5x5, out_7x7
class SSHDetModule(BaseModule):
    """Detection module of `SSH: Single Stage Headless Face Detector.

    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    The input is processed by a plain 3x3 conv producing
    ``out_channels // 2`` channels and, in parallel, by an
    :class:`SSHContextModule` producing two maps of ``out_channels // 4``
    channels each. The three maps are concatenated along the channel
    dimension and passed through a ReLU.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Layer configs forwarded unchanged to both sub-branches.
        layer_cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # Direct branch: half of the output channels, created without an
        # activation config; the ReLU is applied after concatenation.
        self.conv3x3 = ConvModule(
            self.in_channels,
            self.out_channels // 2,
            3,
            stride=1,
            padding=1,
            act_cfg=None,
            **layer_cfgs)
        # Context branch: contributes the remaining two quarters.
        self.context_module = SSHContextModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            **layer_cfgs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the direct and the context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU of the channel-wise concatenation of the
            3x3, 5x5 and 7x7 branch outputs.
        """
        direct = self.conv3x3(x)
        ctx_5x5, ctx_7x7 = self.context_module(x)
        return F.relu(torch.cat([direct, ctx_5x5, ctx_7x7], dim=1))
@MODELS.register_module()
class SSH(BaseModule):
    """`SSH Neck` used in `SSH: Single Stage Headless Face Detector.

    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    One :class:`SSHDetModule` is created per scale and registered as
    ``ssh_module{idx}``; the ``idx``-th input feature map is processed by
    the ``idx``-th module.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (list[int]): The number of input channels per scale.
        out_channels (list[int]): The number of output channels per scale.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to
            dict(type='Xavier', layer='Conv2d', distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [8, 16, 32, 64]
        >>> out_channels = [16, 32, 64, 128]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = SSH(num_scales=4, in_channels=in_channels,
        ...            out_channels=out_channels)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 16, 340, 340])
        outputs[1].shape = torch.Size([1, 32, 170, 170])
        outputs[2].shape = torch.Size([1, 64, 84, 84])
        outputs[3].shape = torch.Size([1, 128, 43, 43])
    """

    def __init__(self,
                 num_scales: int,
                 in_channels: List[int],
                 out_channels: List[int],
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert (num_scales == len(in_channels) == len(out_channels)), (
            f'num_scales ({num_scales}) must equal len(in_channels) '
            f'({len(in_channels)}) and len(out_channels) '
            f'({len(out_channels)})')
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The assert above guarantees both lists have ``num_scales``
        # entries, so zipping them visits every scale exactly once.
        for idx, (in_c, out_c) in enumerate(
                zip(self.in_channels, self.out_channels)):
            self.add_module(
                f'ssh_module{idx}',
                SSHDetModule(
                    in_channels=in_c,
                    out_channels=out_c,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

    def forward(self, inputs: Tuple[torch.Tensor, ...]) -> tuple:
        """Apply the per-scale detection module to each input feature map.

        Args:
            inputs (tuple[torch.Tensor]): One feature map per scale; its
                length must equal ``num_scales``.

        Returns:
            tuple[torch.Tensor]: One output per scale, in input order.

        Raises:
            AssertionError: If ``len(inputs) != num_scales``.
        """
        assert len(inputs) == self.num_scales, (
            f'expected {self.num_scales} feature maps, got {len(inputs)}')
        # Modules were registered by index, so look each one up by name.
        return tuple(
            getattr(self, f'ssh_module{idx}')(feat)
            for idx, feat in enumerate(inputs))
|