123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Sequence
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmengine.model import BaseModule
- from mmdet.registry import MODELS
- from mmdet.utils import ConfigType, OptMultiConfig
- from ..layers import ResLayer
- from .resnet import BasicBlock
class HourglassModule(BaseModule):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    The module has two branches that are summed in ``forward``: ``up1``
    operates on the input directly, while ``low1 -> low2 -> low3`` is built
    with a stride-2 first layer and its output is interpolated before the
    addition. While ``depth > 1`` the middle layer ``low2`` is another
    ``HourglassModule`` built from the remaining entries of
    ``stage_channels`` / ``stage_blocks``.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            Defaults to `dict(type='BN', requires_grad=True)`
        upsample_cfg (ConfigType): Config dict for interpolate layer.
            Defaults to `dict(mode='nearest')`
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization.
    """

    def __init__(self,
                 depth: int,
                 stage_channels: List[int],
                 stage_blocks: List[int],
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)

        self.depth = depth

        # Index 0 describes this level, index 1 the next (inner) level.
        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Propagate norm_cfg / upsample_cfg so that every nested level
            # uses the configuration requested by the caller rather than
            # silently falling back to the defaults.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg,
                upsample_cfg=upsample_cfg)
        else:
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = F.interpolate
        self.upsample_cfg = upsample_cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Sum of the ``up1`` branch and the interpolated
            ``low1 -> low2 -> low3`` branch.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        # Fixing `scale factor` (e.g. 2) is common for upsampling, but
        # in some cases the spatial size is mismatched and error will arise.
        if 'scale_factor' in self.upsample_cfg:
            up2 = self.up2(low3, **self.upsample_cfg)
        else:
            # Resize to the exact spatial size of `up1` so the addition
            # below is always shape-compatible.
            shape = up1.shape[2:]
            up2 = self.up2(low3, size=shape, **self.upsample_cfg)
        return up1 + up2
@MODELS.register_module()
class HourglassNet(BaseModule):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (Sequence[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (Sequence[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (ConfigType): Dictionary to construct and config norm layer.
            It is applied to the stem, to every HourglassModule and to all
            intermediate / output convs.
        init_cfg (dict or ConfigDict, optional): the config to control the
            initialization. Must be None.

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times: int = 5,
                 num_stacks: int = 2,
                 stage_channels: Sequence = (256, 256, 384, 384, 384, 512),
                 stage_blocks: Sequence = (2, 2, 2, 2, 2, 4),
                 feat_channel: int = 256,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg)

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        # Each hourglass level reads entries [0] and [1] of the (sliced)
        # lists, so one more entry than `downsample_times` is required.
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # Two stride-2 layers in a row: a 7x7 conv followed by a ResLayer.
        self.stem = nn.Sequential(
            ConvModule(
                3, cur_channel // 2, 7, padding=3, stride=2,
                norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock,
                cur_channel // 2,
                cur_channel,
                1,
                stride=2,
                norm_cfg=norm_cfg))

        # Forward norm_cfg so the hourglass stacks use the same norm layer
        # as the rest of the backbone instead of always the default one.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # One block per transition between consecutive stacks; indexed by
        # stack in `forward` (assumes ResLayer supports integer indexing,
        # i.e. is Sequential-like).
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self) -> None:
        """Init module weights."""
        # Training Centripetal Model needs to reset parameters for Conv2d
        super().init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward function.

        Args:
            x (torch.Tensor): Input passed to the stem.

        Returns:
            list[torch.Tensor]: One output feature map per stacked
            hourglass, in stack order.
        """
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            # Between stacks, fuse the current stack input with its
            # (remapped) output to form the input of the next stack.
            if ind < self.num_stacks - 1:
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](
                        out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
|