# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.registry import MODELS
from .resnet import BasicBlock, Bottleneck


class HRModule(BaseModule):
    """High-Resolution Module for HRNet.

    In this module, every branch stacks ``num_blocks`` BasicBlocks or
    Bottlenecks (4 in the default configs), and the multi-resolution
    fusion/exchange across branches happens at the end of the module.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 block_init_cfg=None,
                 init_cfg=None):
        super(HRModule, self).__init__(init_cfg)
        self.block_init_cfg = block_init_cfg
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_BLOCKS({len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_CHANNELS({len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
                                 block.expansion)[1])

        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=self.block_init_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    init_cfg=self.block_init_cfg))

        return Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch j -> output branch i:
                    # 1x1 conv to match channels, then nearest upsampling.
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    # Same branch: identity, handled in forward().
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch j -> output branch i: a chain
                    # of stride-2 3x3 convs; only the last conv changes the
                    # channel count.
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)
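
    # Illustrative example (added, not in the original source): with
    # num_branches=3 and in_channels=[32, 64, 128],
    #   fuse_layers[0][2] = 1x1 conv 128->32 + BN + 4x nearest upsample,
    #   fuse_layers[2][0] = two stride-2 3x3 convs (32->32, then 32->128),
    #   fuse_layers[1][1] = None (identity, added directly in forward()).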

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
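

# Sketch of HRModule I/O (assumed shapes, for illustration only): with
# num_branches=2 and in_channels=[32, 64], forward() takes a list
# [t1, t2] where t1 is (N, 32, H, W) and t2 is (N, 64, H/2, W/2), and
# returns a list of fused tensors with the same shapes; with
# multiscale_output=False only the highest-resolution output is kept.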


@MODELS.register_module()
class HRNet(BaseModule):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    arXiv: <https://arxiv.org/abs/1904.04514>`_.

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, and the configuration for each stage
            must have 5 keys:

                - num_modules (int): The number of HRModules in this stage.
                - num_branches (int): The number of branches in the HRModule.
                - block (str): The type of convolution block
                  ('BASIC' or 'BOTTLENECK').
                - num_blocks (tuple): The number of blocks in each branch.
                  The length must be equal to num_branches.
                - num_channels (tuple): The number of channels in each branch.
                  The length must be equal to num_branches.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for the last norm
            layer in resblocks to let them behave as identity. Default: False.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from mmdet.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
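
    # Note (added): BasicBlock.expansion == 1 and Bottleneck.expansion == 4,
    # so a stage's actual output width is num_channels * block.expansion;
    # e.g. stage1 with block='BOTTLENECK' and num_channels=(64, ) produces
    # 64 * 4 = 256 channels.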

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super(HRNet, self).__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        # Assert configurations of 4 stages are in extra
        assert 'stage1' in extra and 'stage2' in extra \
               and 'stage3' in extra and 'stage4' in extra
        # Assert whether the length of `num_blocks` and `num_channels` are
        # equal to `num_branches`
        for i in range(4):
            cfg = extra[f'stage{i + 1}']
            assert len(cfg['num_blocks']) == cfg['num_branches'] and \
                len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multiscale_output=multiscale_output)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)
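
    # Illustrative example (added, assuming HRNet-W32 widths): transition1
    # maps [256] -> [32, 64]; existing branch 0 gets a stride-1 3x3 conv
    # 256->32, and the new branch 1 gets a stride-2 3x3 conv 256->64,
    # halving the resolution.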

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = []
        block_init_cfg = None
        if self.pretrained is None and not hasattr(
                self, 'init_cfg') and self.zero_init_residual:
            if block is BasicBlock:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            elif block is Bottleneck:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm3'))

        layers.append(
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=block_init_cfg,
            ))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    init_cfg=block_init_cfg))

        return Sequential(*layers)
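
    # Note (added): zero-initializing the last norm layer of each residual
    # block (norm2 for BasicBlock, norm3 for Bottleneck) makes the residual
    # branch start at zero, so each block initially behaves as an identity
    # mapping; this is the standard zero_init_residual trick.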

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        block_init_cfg = None
        if self.pretrained is None and not hasattr(
                self, 'init_cfg') and self.zero_init_residual:
            if block is BasicBlock:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            elif block is Bottleneck:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm3'))

        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules), in_channels

    def forward(self, x):
        """Forward function."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['num_branches']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['num_branches']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['num_branches']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list
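
    # Note (added): the two stride-2 stem convs downsample the input 4x, so
    # with multiscale_output=True the returned list holds four feature maps
    # at strides 4, 8, 16 and 32 relative to the input (matching the
    # docstring example, where a 32x32 input yields 8x8 down to 1x1 maps).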

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(HRNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
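

if __name__ == '__main__':
    # Minimal smoke test (added): a sketch mirroring the docstring example,
    # assuming mmdet and its mmcv/mmengine dependencies are installed. It is
    # illustrative only and not part of the original module.
    import torch

    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))
    model = HRNet(extra, in_channels=1)
    model.eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 1, 32, 32))
    for out in outs:
        # Expected: (1, 32, 8, 8), (1, 64, 4, 4), (1, 128, 2, 2), (1, 256, 1, 1)
        print(tuple(out.shape))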