test_regnet.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import RegNet
  class TestRegnet(TestCase):
      """Unit tests for the ``RegNet`` backbone.

      Every row of ``regnet_test_data`` is exercised twice: once by its
      preset name (``test_regnet_backbone``) and once by its explicit arch
      dict (``test_custom_arch``). Both paths are checked against the same
      expected per-stage output channels.
      """

      # Rows of (preset name, arch dict, expected output channels of the four
      # stages). The same channel list is used for the name and for the dict,
      # so each dict is expected to reproduce its named preset.
      regnet_test_data = [
          ('regnetx_400mf',
           dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
                bot_mul=1.0), [32, 64, 160, 384]),
          ('regnetx_800mf',
           dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
                bot_mul=1.0), [64, 128, 288, 672]),
          ('regnetx_1.6gf',
           dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
                bot_mul=1.0), [72, 168, 408, 912]),
          ('regnetx_3.2gf',
           dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
                bot_mul=1.0), [96, 192, 432, 1008]),
          ('regnetx_4.0gf',
           dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
                bot_mul=1.0), [80, 240, 560, 1360]),
          ('regnetx_6.4gf',
           dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
                bot_mul=1.0), [168, 392, 784, 1624]),
          ('regnetx_8.0gf',
           dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
                bot_mul=1.0), [80, 240, 720, 1920]),
          ('regnetx_12gf',
           dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
                bot_mul=1.0), [224, 448, 896, 2240]),
      ]

      # Expected spatial size of each stage's output for a 224x224 input,
      # i.e. output strides 4, 8, 16 and 32.
      _stage_sizes = (56, 28, 14, 7)

      def _check_forward_shapes(self, arch, out_channels):
          """Build ``RegNet(arch)`` and check the shapes of its outputs.

          Two models are checked. The first uses the default ``out_indices``,
          and only its last feature map is inspected. The second requests all
          four stages via ``out_indices=(0, 1, 2, 3)``.

          Args:
              arch (str | dict): Preset name or arch dict passed to
                  ``RegNet``.
              out_channels (list[int]): Expected channels of the four stage
                  outputs.
          """
          imgs = torch.randn(1, 3, 224, 224)

          # Default out_indices: only the last feature map is checked.
          model = RegNet(arch)
          model.init_weights()
          model.train()
          feat = model(imgs)
          self.assertIsInstance(feat, tuple)
          self.assertEqual(
              feat[-1].shape,
              (1, out_channels[-1], self._stage_sizes[-1],
               self._stage_sizes[-1]))

          # Request the feature maps of all four stages.
          model = RegNet(arch, out_indices=(0, 1, 2, 3))
          model.init_weights()
          model.train()
          feat = model(imgs)
          self.assertEqual(len(feat), 4)
          for stage, (channels, size) in enumerate(
                  zip(out_channels, self._stage_sizes)):
              self.assertEqual(
                  feat[stage].shape, (1, channels, size, size),
                  msg=f'stage {stage}')

      def _test_regnet_backbone(self, arch_name, arch, out_channels):
          """Check a RegNet built from the preset name ``arch_name``.

          ``arch`` is unused here. The signature mirrors the rows of
          ``regnet_test_data`` so both per-row checks share one shape.
          """
          # A string that is not a known preset name must be rejected.
          with self.assertRaises(AssertionError):
              RegNet(arch_name + '233')
          self._check_forward_shapes(arch_name, out_channels)

      def test_regnet_backbone(self):
          for arch_name, arch, out_channels in self.regnet_test_data:
              with self.subTest(
                      arch_name=arch_name, arch=arch, out_channels=out_channels):
                  self._test_regnet_backbone(arch_name, arch, out_channels)

      def _test_custom_arch(self, arch_name, arch, out_channels):
          """Check a RegNet built from the explicit arch dict ``arch``.

          ``arch_name`` is unused here. The signature mirrors the rows of
          ``regnet_test_data`` so both per-row checks share one shape.
          """
          self._check_forward_shapes(arch, out_channels)

      def test_custom_arch(self):
          for arch_name, arch, out_channels in self.regnet_test_data:
              with self.subTest(
                      arch_name=arch_name, arch=arch, out_channels=out_channels):
                  self._test_custom_arch(arch_name, arch, out_channels)

      def test_exception(self):
          """An ``arch`` argument of an unsupported type raises TypeError."""
          # arch must be a str or dict
          with self.assertRaises(TypeError):
              _ = RegNet(50)
  84. # arch must be a str or dict
  85. with self.assertRaises(TypeError):
  86. _ = RegNet(50)