# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.backbones.mobilenet_v2 import MobileNetV2
from .utils import check_norm_state, is_block, is_norm
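
# The helpers imported from .utils are used as follows. These descriptions are
# an assumption inferred from how the helpers are exercised below; they are not
# definitions from this file:
#   check_norm_state(modules, train_state): True if the norm layers'
#       ``.training`` flags all equal ``train_state``.
#   is_norm(m): whether ``m`` is a normalization layer (e.g. BatchNorm/GroupNorm).
#   is_block(m): whether ``m`` is a backbone block (e.g. an inverted residual).
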
def test_mobilenetv2_backbone():
    with pytest.raises(ValueError):
        # frozen_stages must be in range(-1, 8)
        MobileNetV2(frozen_stages=8)
    with pytest.raises(ValueError):
        # out_indices must be in range(0, 8)
        MobileNetV2(out_indices=[8])

    # Test MobileNetV2 with first stage frozen
    frozen_stages = 1
    model = MobileNetV2(frozen_stages=frozen_stages)
    model.train()
    for mod in model.conv1.modules():
        for param in mod.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test MobileNetV2 with norm_eval=True
    model = MobileNetV2(norm_eval=True)
    model.train()
    assert check_norm_state(model.modules(), False)
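    # With norm_eval=True the backbone is expected to keep its norm layers in
    # eval mode even after model.train() is called, which is what the check
    # above verifies.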

    # Test MobileNetV2 forward with widen_factor=1.0
    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 8))
    model.train()
    assert check_norm_state(model.modules(), True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 8
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))
    assert feat[7].shape == torch.Size((1, 1280, 7, 7))
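    # The channel counts above are the standard MobileNetV2 stage widths
    # (16, 24, 32, 64, 96, 160, 320) plus the final 1280-channel conv. With a
    # different widen_factor they are scaled and (assumption: following the
    # usual make_divisible rounding in the implementation) rounded to a
    # multiple of 8, e.g. 24 * 0.5 = 12 -> 16 in the widen_factor=0.5 case
    # below.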

    # Test MobileNetV2 forward with widen_factor=0.5
    model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 8, 112, 112))
    assert feat[1].shape == torch.Size((1, 16, 56, 56))
    assert feat[2].shape == torch.Size((1, 16, 28, 28))
    assert feat[3].shape == torch.Size((1, 32, 14, 14))
    assert feat[4].shape == torch.Size((1, 48, 14, 14))
    assert feat[5].shape == torch.Size((1, 80, 7, 7))
    assert feat[6].shape == torch.Size((1, 160, 7, 7))

    # Test MobileNetV2 forward with widen_factor=2.0
    model = MobileNetV2(widen_factor=2.0, out_indices=range(0, 8))
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[0].shape == torch.Size((1, 32, 112, 112))
    assert feat[1].shape == torch.Size((1, 48, 56, 56))
    assert feat[2].shape == torch.Size((1, 64, 28, 28))
    assert feat[3].shape == torch.Size((1, 128, 14, 14))
    assert feat[4].shape == torch.Size((1, 192, 14, 14))
    assert feat[5].shape == torch.Size((1, 320, 7, 7))
    assert feat[6].shape == torch.Size((1, 640, 7, 7))
    assert feat[7].shape == torch.Size((1, 2560, 7, 7))

    # Test MobileNetV2 forward with act_cfg=dict(type='ReLU')
    model = MobileNetV2(
        widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7))
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with BatchNorm forward
    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with GroupNorm forward
    model = MobileNetV2(
        widen_factor=1.0,
        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
        out_indices=range(0, 7))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with layers 1, 3, 5 out forward
    model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 32, 28, 28))
    assert feat[2].shape == torch.Size((1, 96, 14, 14))

    # Test MobileNetV2 with checkpoint forward
    model = MobileNetV2(
        widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
    for m in model.modules():
        if is_block(m):
            assert m.with_cp
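    # with_cp=True is expected to make each block run its forward pass through
    # torch.utils.checkpoint (trading compute for memory); the loop above only
    # checks that the flag is propagated to every block.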
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))
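

# To run just this test in isolation (assumption: the file lives in the mmdet
# test suite, e.g. under tests/test_models/test_backbones/, and pytest is
# installed), invoke pytest on it from the repository root:
#
#   pytest tests/test_models/test_backbones/test_mobilenet_v2.py::test_mobilenetv2_backbone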