hourglass.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Sequence
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule
  7. from mmengine.model import BaseModule
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import ConfigType, OptMultiConfig
  10. from ..layers import ResLayer
  11. from .resnet import BasicBlock
  12. class HourglassModule(BaseModule):
  13. """Hourglass Module for HourglassNet backbone.
  14. Generate module recursively and use BasicBlock as the base unit.
  15. Args:
  16. depth (int): Depth of current HourglassModule.
  17. stage_channels (list[int]): Feature channels of sub-modules in current
  18. and follow-up HourglassModule.
  19. stage_blocks (list[int]): Number of sub-modules stacked in current and
  20. follow-up HourglassModule.
  21. norm_cfg (ConfigType): Dictionary to construct and config norm layer.
  22. Defaults to `dict(type='BN', requires_grad=True)`
  23. upsample_cfg (ConfigType): Config dict for interpolate layer.
  24. Defaults to `dict(mode='nearest')`
  25. init_cfg (dict or ConfigDict, optional): the config to control the
  26. initialization.
  27. """
  28. def __init__(self,
  29. depth: int,
  30. stage_channels: List[int],
  31. stage_blocks: List[int],
  32. norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
  33. upsample_cfg: ConfigType = dict(mode='nearest'),
  34. init_cfg: OptMultiConfig = None) -> None:
  35. super().__init__(init_cfg)
  36. self.depth = depth
  37. cur_block = stage_blocks[0]
  38. next_block = stage_blocks[1]
  39. cur_channel = stage_channels[0]
  40. next_channel = stage_channels[1]
  41. self.up1 = ResLayer(
  42. BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)
  43. self.low1 = ResLayer(
  44. BasicBlock,
  45. cur_channel,
  46. next_channel,
  47. cur_block,
  48. stride=2,
  49. norm_cfg=norm_cfg)
  50. if self.depth > 1:
  51. self.low2 = HourglassModule(depth - 1, stage_channels[1:],
  52. stage_blocks[1:])
  53. else:
  54. self.low2 = ResLayer(
  55. BasicBlock,
  56. next_channel,
  57. next_channel,
  58. next_block,
  59. norm_cfg=norm_cfg)
  60. self.low3 = ResLayer(
  61. BasicBlock,
  62. next_channel,
  63. cur_channel,
  64. cur_block,
  65. norm_cfg=norm_cfg,
  66. downsample_first=False)
  67. self.up2 = F.interpolate
  68. self.upsample_cfg = upsample_cfg
  69. def forward(self, x: torch.Tensor) -> nn.Module:
  70. """Forward function."""
  71. up1 = self.up1(x)
  72. low1 = self.low1(x)
  73. low2 = self.low2(low1)
  74. low3 = self.low3(low2)
  75. # Fixing `scale factor` (e.g. 2) is common for upsampling, but
  76. # in some cases the spatial size is mismatched and error will arise.
  77. if 'scale_factor' in self.upsample_cfg:
  78. up2 = self.up2(low3, **self.upsample_cfg)
  79. else:
  80. shape = up1.shape[2:]
  81. up2 = self.up2(low3, size=shape, **self.upsample_cfg)
  82. return up1 + up2
  83. @MODELS.register_module()
  84. class HourglassNet(BaseModule):
  85. """HourglassNet backbone.
  86. Stacked Hourglass Networks for Human Pose Estimation.
  87. More details can be found in the `paper
  88. <https://arxiv.org/abs/1603.06937>`_ .
  89. Args:
  90. downsample_times (int): Downsample times in a HourglassModule.
  91. num_stacks (int): Number of HourglassModule modules stacked,
  92. 1 for Hourglass-52, 2 for Hourglass-104.
  93. stage_channels (Sequence[int]): Feature channel of each sub-module in a
  94. HourglassModule.
  95. stage_blocks (Sequence[int]): Number of sub-modules stacked in a
  96. HourglassModule.
  97. feat_channel (int): Feature channel of conv after a HourglassModule.
  98. norm_cfg (norm_cfg): Dictionary to construct and config norm layer.
  99. init_cfg (dict or ConfigDict, optional): the config to control the
  100. initialization.
  101. Example:
  102. >>> from mmdet.models import HourglassNet
  103. >>> import torch
  104. >>> self = HourglassNet()
  105. >>> self.eval()
  106. >>> inputs = torch.rand(1, 3, 511, 511)
  107. >>> level_outputs = self.forward(inputs)
  108. >>> for level_output in level_outputs:
  109. ... print(tuple(level_output.shape))
  110. (1, 256, 128, 128)
  111. (1, 256, 128, 128)
  112. """
  113. def __init__(self,
  114. downsample_times: int = 5,
  115. num_stacks: int = 2,
  116. stage_channels: Sequence = (256, 256, 384, 384, 384, 512),
  117. stage_blocks: Sequence = (2, 2, 2, 2, 2, 4),
  118. feat_channel: int = 256,
  119. norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
  120. init_cfg: OptMultiConfig = None) -> None:
  121. assert init_cfg is None, 'To prevent abnormal initialization ' \
  122. 'behavior, init_cfg is not allowed to be set'
  123. super().__init__(init_cfg)
  124. self.num_stacks = num_stacks
  125. assert self.num_stacks >= 1
  126. assert len(stage_channels) == len(stage_blocks)
  127. assert len(stage_channels) > downsample_times
  128. cur_channel = stage_channels[0]
  129. self.stem = nn.Sequential(
  130. ConvModule(
  131. 3, cur_channel // 2, 7, padding=3, stride=2,
  132. norm_cfg=norm_cfg),
  133. ResLayer(
  134. BasicBlock,
  135. cur_channel // 2,
  136. cur_channel,
  137. 1,
  138. stride=2,
  139. norm_cfg=norm_cfg))
  140. self.hourglass_modules = nn.ModuleList([
  141. HourglassModule(downsample_times, stage_channels, stage_blocks)
  142. for _ in range(num_stacks)
  143. ])
  144. self.inters = ResLayer(
  145. BasicBlock,
  146. cur_channel,
  147. cur_channel,
  148. num_stacks - 1,
  149. norm_cfg=norm_cfg)
  150. self.conv1x1s = nn.ModuleList([
  151. ConvModule(
  152. cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
  153. for _ in range(num_stacks - 1)
  154. ])
  155. self.out_convs = nn.ModuleList([
  156. ConvModule(
  157. cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
  158. for _ in range(num_stacks)
  159. ])
  160. self.remap_convs = nn.ModuleList([
  161. ConvModule(
  162. feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
  163. for _ in range(num_stacks - 1)
  164. ])
  165. self.relu = nn.ReLU(inplace=True)
  166. def init_weights(self) -> None:
  167. """Init module weights."""
  168. # Training Centripetal Model needs to reset parameters for Conv2d
  169. super().init_weights()
  170. for m in self.modules():
  171. if isinstance(m, nn.Conv2d):
  172. m.reset_parameters()
  173. def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
  174. """Forward function."""
  175. inter_feat = self.stem(x)
  176. out_feats = []
  177. for ind in range(self.num_stacks):
  178. single_hourglass = self.hourglass_modules[ind]
  179. out_conv = self.out_convs[ind]
  180. hourglass_feat = single_hourglass(inter_feat)
  181. out_feat = out_conv(hourglass_feat)
  182. out_feats.append(out_feat)
  183. if ind < self.num_stacks - 1:
  184. inter_feat = self.conv1x1s[ind](
  185. inter_feat) + self.remap_convs[ind](
  186. out_feat)
  187. inter_feat = self.inters[ind](self.relu(inter_feat))
  188. return out_feats