test_mobilenet_v2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from torch.nn.modules import GroupNorm
  5. from torch.nn.modules.batchnorm import _BatchNorm
  6. from mmpose.models.backbones import MobileNetV2
  7. from mmpose.models.backbones.mobilenet_v2 import InvertedResidual
  8. class TestMobilenetV2(TestCase):
  9. @staticmethod
  10. def is_block(modules):
  11. """Check if is ResNet building block."""
  12. if isinstance(modules, (InvertedResidual, )):
  13. return True
  14. return False
  15. @staticmethod
  16. def is_norm(modules):
  17. """Check if is one of the norms."""
  18. if isinstance(modules, (GroupNorm, _BatchNorm)):
  19. return True
  20. return False
  21. @staticmethod
  22. def check_norm_state(modules, train_state):
  23. """Check if norm layer is in correct train state."""
  24. for mod in modules:
  25. if isinstance(mod, _BatchNorm):
  26. if mod.training != train_state:
  27. return False
  28. return True
  29. def test_mobilenetv2_invertedresidual(self):
  30. with self.assertRaises(AssertionError):
  31. # stride must be in [1, 2]
  32. InvertedResidual(16, 24, stride=3, expand_ratio=6)
  33. # Test InvertedResidual with checkpoint forward, stride=1
  34. block = InvertedResidual(16, 24, stride=1, expand_ratio=6)
  35. x = torch.randn(1, 16, 56, 56)
  36. x_out = block(x)
  37. self.assertEqual(x_out.shape, torch.Size((1, 24, 56, 56)))
  38. # Test InvertedResidual with expand_ratio=1
  39. block = InvertedResidual(16, 16, stride=1, expand_ratio=1)
  40. self.assertEqual(len(block.conv), 2)
  41. # Test InvertedResidual with use_res_connect
  42. block = InvertedResidual(16, 16, stride=1, expand_ratio=6)
  43. x = torch.randn(1, 16, 56, 56)
  44. x_out = block(x)
  45. self.assertTrue(block.use_res_connect)
  46. self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))
  47. # Test InvertedResidual with checkpoint forward, stride=2
  48. block = InvertedResidual(16, 24, stride=2, expand_ratio=6)
  49. x = torch.randn(1, 16, 56, 56)
  50. x_out = block(x)
  51. self.assertEqual(x_out.shape, torch.Size((1, 24, 28, 28)))
  52. # Test InvertedResidual with checkpoint forward
  53. block = InvertedResidual(
  54. 16, 24, stride=1, expand_ratio=6, with_cp=True)
  55. self.assertTrue(block.with_cp)
  56. x = torch.randn(1, 16, 56, 56)
  57. x_out = block(x)
  58. self.assertEqual(x_out.shape, torch.Size((1, 24, 56, 56)))
  59. # Test InvertedResidual with act_cfg=dict(type='ReLU')
  60. block = InvertedResidual(
  61. 16, 24, stride=1, expand_ratio=6, act_cfg=dict(type='ReLU'))
  62. x = torch.randn(1, 16, 56, 56)
  63. x_out = block(x)
  64. self.assertEqual(x_out.shape, torch.Size((1, 24, 56, 56)))
  65. def test_mobilenetv2_backbone(self):
  66. with self.assertRaises(TypeError):
  67. # pretrained must be a string path
  68. model = MobileNetV2()
  69. model.init_weights(pretrained=0)
  70. with self.assertRaises(ValueError):
  71. # frozen_stages must in range(1, 8)
  72. MobileNetV2(frozen_stages=8)
  73. with self.assertRaises(ValueError):
  74. # tout_indices in range(-1, 8)
  75. MobileNetV2(out_indices=[8])
  76. # Test MobileNetV2 with first stage frozen
  77. frozen_stages = 1
  78. model = MobileNetV2(frozen_stages=frozen_stages)
  79. model.init_weights()
  80. model.train()
  81. for mod in model.conv1.modules():
  82. for param in mod.parameters():
  83. self.assertFalse(param.requires_grad)
  84. for i in range(1, frozen_stages + 1):
  85. layer = getattr(model, f'layer{i}')
  86. for mod in layer.modules():
  87. if isinstance(mod, _BatchNorm):
  88. self.assertFalse(mod.training)
  89. for param in layer.parameters():
  90. self.assertFalse(param.requires_grad)
  91. # Test MobileNetV2 with norm_eval=True
  92. model = MobileNetV2(norm_eval=True)
  93. model.init_weights()
  94. model.train()
  95. self.assertTrue(self.check_norm_state(model.modules(), False))
  96. # Test MobileNetV2 forward with widen_factor=1.0
  97. model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 8))
  98. model.init_weights()
  99. model.train()
  100. self.assertTrue(self.check_norm_state(model.modules(), True))
  101. imgs = torch.randn(1, 3, 224, 224)
  102. feat = model(imgs)
  103. self.assertEqual(len(feat), 8)
  104. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  105. self.assertEqual(feat[1].shape, torch.Size((1, 24, 56, 56)))
  106. self.assertEqual(feat[2].shape, torch.Size((1, 32, 28, 28)))
  107. self.assertEqual(feat[3].shape, torch.Size((1, 64, 14, 14)))
  108. self.assertEqual(feat[4].shape, torch.Size((1, 96, 14, 14)))
  109. self.assertEqual(feat[5].shape, torch.Size((1, 160, 7, 7)))
  110. self.assertEqual(feat[6].shape, torch.Size((1, 320, 7, 7)))
  111. self.assertEqual(feat[7].shape, torch.Size((1, 1280, 7, 7)))
  112. # Test MobileNetV2 forward with widen_factor=0.5
  113. model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
  114. model.init_weights()
  115. model.train()
  116. imgs = torch.randn(1, 3, 224, 224)
  117. feat = model(imgs)
  118. self.assertEqual(len(feat), 7)
  119. self.assertEqual(feat[0].shape, torch.Size((1, 8, 112, 112)))
  120. self.assertEqual(feat[1].shape, torch.Size((1, 16, 56, 56)))
  121. self.assertEqual(feat[2].shape, torch.Size((1, 16, 28, 28)))
  122. self.assertEqual(feat[3].shape, torch.Size((1, 32, 14, 14)))
  123. self.assertEqual(feat[4].shape, torch.Size((1, 48, 14, 14)))
  124. self.assertEqual(feat[5].shape, torch.Size((1, 80, 7, 7)))
  125. self.assertEqual(feat[6].shape, torch.Size((1, 160, 7, 7)))
  126. # Test MobileNetV2 forward with widen_factor=2.0
  127. model = MobileNetV2(widen_factor=2.0)
  128. model.init_weights()
  129. model.train()
  130. imgs = torch.randn(1, 3, 224, 224)
  131. feat = model(imgs)
  132. self.assertIsInstance(feat, tuple)
  133. self.assertEqual(feat[-1].shape, torch.Size((1, 2560, 7, 7)))
  134. # Test MobileNetV2 forward with out_indices=None
  135. model = MobileNetV2(widen_factor=1.0)
  136. model.init_weights()
  137. model.train()
  138. imgs = torch.randn(1, 3, 224, 224)
  139. feat = model(imgs)
  140. self.assertIsInstance(feat, tuple)
  141. self.assertEqual(feat[-1].shape, torch.Size((1, 1280, 7, 7)))
  142. # Test MobileNetV2 forward with dict(type='ReLU')
  143. model = MobileNetV2(
  144. widen_factor=1.0,
  145. act_cfg=dict(type='ReLU'),
  146. out_indices=range(0, 7))
  147. model.init_weights()
  148. model.train()
  149. imgs = torch.randn(1, 3, 224, 224)
  150. feat = model(imgs)
  151. self.assertEqual(len(feat), 7)
  152. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  153. self.assertEqual(feat[1].shape, torch.Size((1, 24, 56, 56)))
  154. self.assertEqual(feat[2].shape, torch.Size((1, 32, 28, 28)))
  155. self.assertEqual(feat[3].shape, torch.Size((1, 64, 14, 14)))
  156. self.assertEqual(feat[4].shape, torch.Size((1, 96, 14, 14)))
  157. self.assertEqual(feat[5].shape, torch.Size((1, 160, 7, 7)))
  158. self.assertEqual(feat[6].shape, torch.Size((1, 320, 7, 7)))
  159. # Test MobileNetV2 with GroupNorm forward
  160. model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
  161. for m in model.modules():
  162. if self.is_norm(m):
  163. self.assertIsInstance(m, _BatchNorm)
  164. model.init_weights()
  165. model.train()
  166. imgs = torch.randn(1, 3, 224, 224)
  167. feat = model(imgs)
  168. self.assertEqual(len(feat), 7)
  169. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  170. self.assertEqual(feat[1].shape, torch.Size((1, 24, 56, 56)))
  171. self.assertEqual(feat[2].shape, torch.Size((1, 32, 28, 28)))
  172. self.assertEqual(feat[3].shape, torch.Size((1, 64, 14, 14)))
  173. self.assertEqual(feat[4].shape, torch.Size((1, 96, 14, 14)))
  174. self.assertEqual(feat[5].shape, torch.Size((1, 160, 7, 7)))
  175. self.assertEqual(feat[6].shape, torch.Size((1, 320, 7, 7)))
  176. # Test MobileNetV2 with BatchNorm forward
  177. model = MobileNetV2(
  178. widen_factor=1.0,
  179. norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
  180. out_indices=range(0, 7))
  181. for m in model.modules():
  182. if self.is_norm(m):
  183. self.assertIsInstance(m, GroupNorm)
  184. model.init_weights()
  185. model.train()
  186. imgs = torch.randn(1, 3, 224, 224)
  187. feat = model(imgs)
  188. self.assertEqual(len(feat), 7)
  189. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  190. self.assertEqual(feat[1].shape, torch.Size((1, 24, 56, 56)))
  191. self.assertEqual(feat[2].shape, torch.Size((1, 32, 28, 28)))
  192. self.assertEqual(feat[3].shape, torch.Size((1, 64, 14, 14)))
  193. self.assertEqual(feat[4].shape, torch.Size((1, 96, 14, 14)))
  194. self.assertEqual(feat[5].shape, torch.Size((1, 160, 7, 7)))
  195. self.assertEqual(feat[6].shape, torch.Size((1, 320, 7, 7)))
  196. # Test MobileNetV2 with layers 1, 3, 5 out forward
  197. model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
  198. model.init_weights()
  199. model.train()
  200. imgs = torch.randn(1, 3, 224, 224)
  201. feat = model(imgs)
  202. self.assertEqual(len(feat), 3)
  203. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  204. self.assertEqual(feat[1].shape, torch.Size((1, 32, 28, 28)))
  205. self.assertEqual(feat[2].shape, torch.Size((1, 96, 14, 14)))
  206. # Test MobileNetV2 with checkpoint forward
  207. model = MobileNetV2(
  208. widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
  209. for m in model.modules():
  210. if self.is_block(m):
  211. self.assertTrue(m.with_cp)
  212. model.init_weights()
  213. model.train()
  214. imgs = torch.randn(1, 3, 224, 224)
  215. feat = model(imgs)
  216. self.assertEqual(len(feat), 7)
  217. self.assertEqual(feat[0].shape, torch.Size((1, 16, 112, 112)))
  218. self.assertEqual(feat[1].shape, torch.Size((1, 24, 56, 56)))
  219. self.assertEqual(feat[2].shape, torch.Size((1, 32, 28, 28)))
  220. self.assertEqual(feat[3].shape, torch.Size((1, 64, 14, 14)))
  221. self.assertEqual(feat[4].shape, torch.Size((1, 96, 14, 14)))
  222. self.assertEqual(feat[5].shape, torch.Size((1, 160, 7, 7)))
  223. self.assertEqual(feat[6].shape, torch.Size((1, 320, 7, 7)))