test_hrnet.py

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.backbones import HRNet
from mmpose.models.backbones.hrnet import HRModule
from mmpose.models.backbones.resnet import BasicBlock, Bottleneck


class TestHrnet(TestCase):

    @staticmethod
    def is_block(modules):
        """Check whether the module is an HRModule building block."""
        if isinstance(modules, (HRModule, )):
            return True
        return False

    @staticmethod
    def is_norm(modules):
        """Check whether the module is a normalization layer."""
        if isinstance(modules, (_BatchNorm, )):
            return True
        return False

    @staticmethod
    def all_zeros(modules):
        """Check whether the weight (and bias, if present) are all zeros."""
        weight_zero = torch.equal(modules.weight.data,
                                  torch.zeros_like(modules.weight.data))
        if hasattr(modules, 'bias'):
            bias_zero = torch.equal(modules.bias.data,
                                    torch.zeros_like(modules.bias.data))
        else:
            bias_zero = True

        return weight_zero and bias_zero
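
    # `all_zeros` backs the zero_init_residual check below: zero-initializing
    # the last BatchNorm (norm3) of each Bottleneck makes every residual
    # branch start out as an identity mapping, a common training trick.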

    def test_hrmodule(self):
        # Test HRModule forward
        block = HRModule(
            num_branches=1,
            blocks=BasicBlock,
            num_blocks=(4, ),
            in_channels=[64],
            num_channels=(64, ))

        x = torch.randn(2, 64, 56, 56)
        x_out = block([x])

        self.assertEqual(x_out[0].shape, torch.Size([2, 64, 56, 56]))
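
    # HRModule preserves each branch's resolution and channel count, so the
    # single-branch module above returns a list whose only tensor matches
    # the input shape exactly.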

    def test_hrnet_backbone(self):
        extra = dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256)))

        model = HRNet(extra, in_channels=3)

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 32, 56, 56]))
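
        # The stem downsamples by 4 (two stride-2 convs), so a 224x224 input
        # yields 56x56 maps; with multiscale_output left at its default,
        # only the 32-channel highest-resolution branch is returned, hence
        # the (2, 32, 56, 56) shape checked above.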

        # Test HRNet zero initialization of residual
        model = HRNet(extra, in_channels=3, zero_init_residual=True)
        model.init_weights()
        for m in model.modules():
            if isinstance(m, Bottleneck):
                self.assertTrue(self.all_zeros(m.norm3))
        model.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 32, 56, 56]))

        # Test HRNet with the first three stages frozen. model.train()
        # normally flips BatchNorm back to training mode, so the checks
        # below verify that frozen layers stay in eval mode and their
        # parameters stay gradient-free.
        frozen_stages = 3
        model = HRNet(extra, in_channels=3, frozen_stages=frozen_stages)
        model.init_weights()
        model.train()

        # The stem (conv1/norm1, conv2/norm2) is frozen whenever
        # frozen_stages >= 0.
        if frozen_stages >= 0:
            self.assertFalse(model.norm1.training)
            self.assertFalse(model.norm2.training)
            for layer in [model.conv1, model.norm1, model.conv2, model.norm2]:
                for param in layer.parameters():
                    self.assertFalse(param.requires_grad)

        for i in range(1, frozen_stages + 1):
            # Stage 1 is registered as `layer1`; later stages as `stage{i}`.
            if i == 1:
                layer = getattr(model, 'layer1')
            else:
                layer = getattr(model, f'stage{i}')
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    self.assertFalse(mod.training)
            for param in layer.parameters():
                self.assertFalse(param.requires_grad)

            # The transition layer following each frozen stage is frozen too.
            if i < 4:
                layer = getattr(model, f'transition{i}')
                for mod in layer.modules():
                    if isinstance(mod, _BatchNorm):
                        self.assertFalse(mod.training)
                for param in layer.parameters():
                    self.assertFalse(param.requires_grad)
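

# A minimal sketch for running this file directly (assumes mmpose and a
# recent PyTorch are installed); under pytest or `python -m unittest` this
# __main__ block is simply ignored.
if __name__ == '__main__':
    import unittest

    unittest.main()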