# test_hrformer.py (6.3 KB)

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones.hrformer import (HRFomerModule, HRFormer,
  5. HRFormerBlock)
  6. class TestHrformer(TestCase):
  7. def test_hrformer_module(self):
  8. norm_cfg = dict(type='BN')
  9. block = HRFormerBlock
  10. # Test multiscale forward
  11. num_channles = (32, 64)
  12. num_inchannels = [c * block.expansion for c in num_channles]
  13. hrmodule = HRFomerModule(
  14. num_branches=2,
  15. block=block,
  16. num_blocks=(2, 2),
  17. num_inchannels=num_inchannels,
  18. num_channels=num_channles,
  19. num_heads=(1, 2),
  20. num_window_sizes=(7, 7),
  21. num_mlp_ratios=(4, 4),
  22. drop_paths=(0., 0.),
  23. norm_cfg=norm_cfg)
  24. feats = [
  25. torch.randn(1, num_inchannels[0], 64, 64),
  26. torch.randn(1, num_inchannels[1], 32, 32)
  27. ]
  28. feats = hrmodule(feats)
  29. self.assertGreater(len(str(hrmodule)), 0)
  30. self.assertEqual(len(feats), 2)
  31. self.assertEqual(feats[0].shape,
  32. torch.Size([1, num_inchannels[0], 64, 64]))
  33. self.assertEqual(feats[1].shape,
  34. torch.Size([1, num_inchannels[1], 32, 32]))
  35. # Test single scale forward
  36. num_channles = (32, 64)
  37. in_channels = [c * block.expansion for c in num_channles]
  38. hrmodule = HRFomerModule(
  39. num_branches=2,
  40. block=block,
  41. num_blocks=(2, 2),
  42. num_inchannels=num_inchannels,
  43. num_channels=num_channles,
  44. num_heads=(1, 2),
  45. num_window_sizes=(7, 7),
  46. num_mlp_ratios=(4, 4),
  47. drop_paths=(0., 0.),
  48. norm_cfg=norm_cfg,
  49. multiscale_output=False,
  50. )
  51. feats = [
  52. torch.randn(1, in_channels[0], 64, 64),
  53. torch.randn(1, in_channels[1], 32, 32)
  54. ]
  55. feats = hrmodule(feats)
  56. self.assertEqual(len(feats), 1)
  57. self.assertEqual(feats[0].shape,
  58. torch.Size([1, in_channels[0], 64, 64]))
  59. # Test single branch HRFormer module
  60. hrmodule = HRFomerModule(
  61. num_branches=1,
  62. block=block,
  63. num_blocks=(1, ),
  64. num_inchannels=[num_inchannels[0]],
  65. num_channels=[num_channles[0]],
  66. num_heads=(1, ),
  67. num_window_sizes=(7, ),
  68. num_mlp_ratios=(4, ),
  69. drop_paths=(0.1, ),
  70. norm_cfg=norm_cfg,
  71. )
  72. feats = [
  73. torch.randn(1, in_channels[0], 64, 64),
  74. ]
  75. feats = hrmodule(feats)
  76. self.assertEqual(len(feats), 1)
  77. self.assertEqual(feats[0].shape,
  78. torch.Size([1, in_channels[0], 64, 64]))
  79. # Value tests
  80. kwargs = dict(
  81. num_branches=2,
  82. block=block,
  83. num_blocks=(2, 2),
  84. num_inchannels=num_inchannels,
  85. num_channels=num_channles,
  86. num_heads=(1, 2),
  87. num_window_sizes=(7, 7),
  88. num_mlp_ratios=(4, 4),
  89. drop_paths=(0.1, 0.1),
  90. norm_cfg=norm_cfg,
  91. )
  92. with self.assertRaises(ValueError):
  93. # len(num_blocks) should equal num_branches
  94. kwargs['num_blocks'] = [2, 2, 2]
  95. HRFomerModule(**kwargs)
  96. kwargs['num_blocks'] = [2, 2]
  97. with self.assertRaises(ValueError):
  98. # len(num_blocks) should equal num_branches
  99. kwargs['num_channels'] = [2]
  100. HRFomerModule(**kwargs)
  101. kwargs['num_channels'] = [2, 2]
  102. with self.assertRaises(ValueError):
  103. # len(num_blocks) should equal num_branches
  104. kwargs['num_inchannels'] = [2]
  105. HRFomerModule(**kwargs)
  106. kwargs['num_inchannels'] = [2, 2]
  107. def test_hrformer_backbone(self):
  108. norm_cfg = dict(type='BN')
  109. # only have 3 stages
  110. extra = dict(
  111. drop_path_rate=0.2,
  112. stage1=dict(
  113. num_modules=1,
  114. num_branches=1,
  115. block='BOTTLENECK',
  116. num_blocks=(2, ),
  117. num_channels=(64, )),
  118. stage2=dict(
  119. num_modules=1,
  120. num_branches=2,
  121. block='HRFORMERBLOCK',
  122. window_sizes=(7, 7),
  123. num_heads=(1, 2),
  124. mlp_ratios=(4, 4),
  125. num_blocks=(2, 2),
  126. num_channels=(32, 64)),
  127. stage3=dict(
  128. num_modules=4,
  129. num_branches=3,
  130. block='HRFORMERBLOCK',
  131. window_sizes=(7, 7, 7),
  132. num_heads=(1, 2, 4),
  133. mlp_ratios=(4, 4, 4),
  134. num_blocks=(2, 2, 2),
  135. num_channels=(32, 64, 128)),
  136. stage4=dict(
  137. num_modules=3,
  138. num_branches=4,
  139. block='HRFORMERBLOCK',
  140. window_sizes=(7, 7, 7, 7),
  141. num_heads=(1, 2, 4, 8),
  142. mlp_ratios=(4, 4, 4, 4),
  143. num_blocks=(2, 2, 2, 2),
  144. num_channels=(32, 64, 128, 256),
  145. multiscale_output=True))
  146. with self.assertRaises(ValueError):
  147. # len(num_blocks) should equal num_branches
  148. extra['stage4']['num_branches'] = 3
  149. HRFormer(extra=extra)
  150. extra['stage4']['num_branches'] = 4
  151. # Test HRFormer-S
  152. model = HRFormer(extra=extra, norm_cfg=norm_cfg)
  153. model.init_weights()
  154. model.train()
  155. imgs = torch.randn(1, 3, 64, 64)
  156. feats = model(imgs)
  157. self.assertEqual(len(feats), 4)
  158. self.assertEqual(feats[0].shape, torch.Size([1, 32, 16, 16]))
  159. self.assertEqual(feats[3].shape, torch.Size([1, 256, 2, 2]))
  160. # Test single scale output and model
  161. # without relative position bias
  162. extra['stage4']['multiscale_output'] = False
  163. extra['with_rpe'] = False
  164. model = HRFormer(extra=extra, norm_cfg=norm_cfg)
  165. model.init_weights()
  166. model.train()
  167. imgs = torch.randn(1, 3, 64, 64)
  168. feats = model(imgs)
  169. self.assertIsInstance(feats, tuple)
  170. self.assertEqual(len(feats), 1)
  171. self.assertEqual(feats[-1].shape, torch.Size([1, 32, 16, 16]))