test_vgg.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
  5. from mmpose.models.backbones import VGG
  class TestVGG(TestCase):
      """Unit tests for the ``VGG`` backbone."""

      # Expected output shape of each of the five VGG stages for a
      # (1, 3, 224, 224) input; each stage halves the spatial resolution.
      _STAGE_SHAPES = (
          (1, 64, 112, 112),
          (1, 128, 56, 56),
          (1, 256, 28, 28),
          (1, 512, 14, 14),
          (1, 512, 7, 7),
      )

      @staticmethod
      def check_norm_state(modules, train_state):
          """Check if norm layer is in correct train state.

          Args:
              modules (Iterable[nn.Module]): Modules to inspect.
              train_state (bool): Expected value of ``training`` for every
                  ``_BatchNorm`` instance in ``modules``.

          Returns:
              bool: True if every ``_BatchNorm`` in ``modules`` has
              ``training == train_state``. Vacuously True when ``modules``
              contains no ``_BatchNorm`` at all.
          """
          return all(mod.training == train_state for mod in modules
                     if isinstance(mod, _BatchNorm))

      def _check_forward(self, model, expected_shapes):
          """Initialise ``model``, run one forward pass and check outputs.

          Args:
              model (VGG): Backbone under test.
              expected_shapes (Sequence[tuple]): Expected shape of each
                  tensor returned by ``model``, in output order.
          """
          model.init_weights()
          model.train()
          imgs = torch.randn(1, 3, 224, 224)
          feat = model(imgs)
          self.assertEqual(len(feat), len(expected_shapes))
          for idx, (out, shape) in enumerate(zip(feat, expected_shapes)):
              # subTest reports which output index has the wrong shape.
              with self.subTest(output=idx):
                  self.assertEqual(out.shape, shape)

      def test_vgg(self):
          """Test VGG backbone."""
          with self.assertRaises(KeyError):
              # VGG depth should be in [11, 13, 16, 19]
              VGG(18)
          with self.assertRaises(AssertionError):
              # In VGG: 1 <= num_stages <= 5
              VGG(11, num_stages=0)
          with self.assertRaises(AssertionError):
              # In VGG: 1 <= num_stages <= 5
              VGG(11, num_stages=6)
          with self.assertRaises(AssertionError):
              # len(dilations) == num_stages
              VGG(11, dilations=(1, 1), num_stages=3)

          # Test VGG11 norm_eval=True. Norm layers are enabled explicitly
          # (the same way the VGG11BN cases below do it) and their presence
          # is asserted, otherwise check_norm_state would pass vacuously on
          # a model that contains no _BatchNorm modules.
          model = VGG(11, norm_cfg=dict(type='BN'), norm_eval=True)
          model.init_weights()
          model.train()
          self.assertTrue(
              any(isinstance(mod, _BatchNorm) for mod in model.modules()))
          self.assertTrue(self.check_norm_state(model.modules(), False))

          # Test VGG11 forward without classifiers
          self._check_forward(
              VGG(11, out_indices=(0, 1, 2, 3, 4)), self._STAGE_SHAPES)

          # Test VGG11 forward with classifiers
          self._check_forward(
              VGG(11, num_classes=10, out_indices=(0, 1, 2, 3, 4, 5)),
              self._STAGE_SHAPES + ((1, 10), ))

          # Test VGG11BN forward
          self._check_forward(
              VGG(11, norm_cfg=dict(type='BN'), out_indices=(0, 1, 2, 3, 4)),
              self._STAGE_SHAPES)

          # Test VGG11BN forward with classifiers
          self._check_forward(
              VGG(11,
                  num_classes=10,
                  norm_cfg=dict(type='BN'),
                  out_indices=(0, 1, 2, 3, 4, 5)),
              self._STAGE_SHAPES + ((1, 10), ))

          # Test VGG13 with layers 1, 2, 3 out forward
          self._check_forward(
              VGG(13, out_indices=(0, 1, 2)), self._STAGE_SHAPES[:3])

          # Test VGG16 with top feature maps out forward
          self._check_forward(VGG(16), self._STAGE_SHAPES[-1:])

          # Test VGG19 with classification score out forward
          self._check_forward(VGG(19, num_classes=10), ((1, 10), ))