test_litehrnet.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from torch.nn.modules.batchnorm import _BatchNorm
  5. from mmpose.models.backbones import LiteHRNet
  6. from mmpose.models.backbones.litehrnet import LiteHRModule
  7. from mmpose.models.backbones.resnet import Bottleneck
  class TestLiteHrnet(TestCase):
      """Unit tests for the LiteHRNet backbone and its LiteHRModule."""

      @staticmethod
      def is_norm(modules):
          """Check if is one of the norms.

          Args:
              modules (nn.Module): Module to inspect.

          Returns:
              bool: ``True`` if ``modules`` is a ``_BatchNorm`` instance.
          """
          return isinstance(modules, _BatchNorm)

      @staticmethod
      def all_zeros(modules):
          """Check if the weight (and bias, when present) is all zero.

          Args:
              modules (nn.Module): Module exposing a ``weight`` parameter and
                  optionally a ``bias`` parameter.

          Returns:
              bool: ``True`` if ``weight`` is all zeros and ``bias`` is either
              missing, ``None``, or all zeros.
          """
          weight = modules.weight.data
          weight_zero = torch.equal(weight, torch.zeros_like(weight))
          # Layers constructed without a bias (e.g. ``bias=False``) still
          # expose a ``bias`` attribute whose value is ``None``; a bare
          # ``hasattr`` check would then crash on ``None.data``.
          bias = getattr(modules, 'bias', None)
          if bias is None:
              bias_zero = True
          else:
              bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))
          return weight_zero and bias_zero

      @staticmethod
      def _build_litehr_module(module_type):
          """Build the single-branch, single-block LiteHRModule under test.

          Args:
              module_type (str): Forwarded to ``LiteHRModule(module_type=...)``.

          Returns:
              LiteHRModule: Module with ``in_channels=[40]`` and
              ``reduce_ratio=8``.
          """
          return LiteHRModule(
              num_branches=1,
              num_blocks=1,
              in_channels=[
                  40,
              ],
              reduce_ratio=8,
              module_type=module_type)

      def test_litehrmodule(self):
          """Check LiteHRModule forward for 'LITE' and 'NAIVE', and that an
          unknown ``module_type`` raises ``ValueError``."""
          expected_shape = torch.Size([2, 40, 56, 56])

          # 'LITE': the test feeds a nested list and reads a nested output.
          block = self._build_litehr_module('LITE')
          x_out = block([[torch.randn(2, 40, 56, 56)]])
          self.assertEqual(x_out[0][0].shape, expected_shape)

          # 'NAIVE': the test feeds a flat list of branch inputs.
          block = self._build_litehr_module('NAIVE')
          x_out = block([torch.randn(2, 40, 56, 56)])
          self.assertEqual(x_out[0].shape, expected_shape)

          with self.assertRaises(ValueError):
              self._build_litehr_module('none')

      @staticmethod
      def _litehrnet_extra(module_type):
          """Return the 3-stage LiteHRNet ``extra`` config used by the tests.

          Args:
              module_type (str): Module type applied to every stage.

          Returns:
              dict: Config forwarded as the first argument of ``LiteHRNet``.
          """
          return dict(
              stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
              num_stages=3,
              stages_spec=dict(
                  num_modules=(2, 4, 2),
                  num_branches=(2, 3, 4),
                  num_blocks=(2, 2, 2),
                  module_type=(module_type, ) * 3,
                  with_fuse=(True, True, True),
                  reduce_ratios=(8, 8, 8),
                  num_channels=(
                      (40, 80),
                      (40, 80, 160),
                      (40, 80, 160, 320),
                  )),
              with_head=True)

      def _check_forward(self, model):
          """Run ``model`` on a random 2x3x224x224 batch and check that it
          returns a tuple whose last feature has shape (2, 40, 56, 56)."""
          imgs = torch.randn(2, 3, 224, 224)
          feat = model(imgs)
          self.assertIsInstance(feat, tuple)
          self.assertEqual(feat[-1].shape, torch.Size([2, 40, 56, 56]))

      def test_litehrnet_backbone(self):
          """Build LiteHRNet with all-'LITE' and all-'NAIVE' stages and check
          the forward output both freshly built and after ``init_weights()``
          followed by ``train()``."""
          for module_type in ('LITE', 'NAIVE'):
              with self.subTest(module_type=module_type):
                  extra = self._litehrnet_extra(module_type)

                  model = LiteHRNet(extra, in_channels=3)
                  self._check_forward(model)

                  # Test zero initialization of residual branches.
                  # NOTE(review): the assertion only fires for ``Bottleneck``
                  # modules; this file does not show that LiteHRNet builds
                  # any, in which case the check passes vacuously - confirm.
                  model = LiteHRNet(extra, in_channels=3)
                  model.init_weights()
                  for m in model.modules():
                      if isinstance(m, Bottleneck):
                          self.assertTrue(self.all_zeros(m.norm3))
                  model.train()
                  self._check_forward(model)