test_seresnet.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from torch.nn.modules import AvgPool2d
  5. from torch.nn.modules.batchnorm import _BatchNorm
  6. from mmpose.models.backbones import SEResNet
  7. from mmpose.models.backbones.resnet import ResLayer
  8. from mmpose.models.backbones.seresnet import SEBottleneck, SELayer
  9. class TestSEResnet(TestCase):
  10. @staticmethod
  11. def all_zeros(modules):
  12. """Check if the weight(and bias) is all zero."""
  13. weight_zero = torch.equal(modules.weight.data,
  14. torch.zeros_like(modules.weight.data))
  15. if hasattr(modules, 'bias'):
  16. bias_zero = torch.equal(modules.bias.data,
  17. torch.zeros_like(modules.bias.data))
  18. else:
  19. bias_zero = True
  20. return weight_zero and bias_zero
  21. @staticmethod
  22. def check_norm_state(modules, train_state):
  23. """Check if norm layer is in correct train state."""
  24. for mod in modules:
  25. if isinstance(mod, _BatchNorm):
  26. if mod.training != train_state:
  27. return False
  28. return True
  29. def test_selayer(self):
  30. # Test selayer forward
  31. layer = SELayer(64)
  32. x = torch.randn(1, 64, 56, 56)
  33. x_out = layer(x)
  34. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  35. # Test selayer forward with different ratio
  36. layer = SELayer(64, ratio=8)
  37. x = torch.randn(1, 64, 56, 56)
  38. x_out = layer(x)
  39. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  40. def test_bottleneck(self):
  41. with self.assertRaises(AssertionError):
  42. # Style must be in ['pytorch', 'caffe']
  43. SEBottleneck(64, 64, style='tensorflow')
  44. # Test SEBottleneck with checkpoint forward
  45. block = SEBottleneck(64, 64, with_cp=True)
  46. self.assertTrue(block.with_cp)
  47. x = torch.randn(1, 64, 56, 56)
  48. x_out = block(x)
  49. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  50. # Test Bottleneck style
  51. block = SEBottleneck(64, 256, stride=2, style='pytorch')
  52. self.assertEqual(block.conv1.stride, (1, 1))
  53. self.assertEqual(block.conv2.stride, (2, 2))
  54. block = SEBottleneck(64, 256, stride=2, style='caffe')
  55. self.assertEqual(block.conv1.stride, (2, 2))
  56. self.assertEqual(block.conv2.stride, (1, 1))
  57. # Test Bottleneck forward
  58. block = SEBottleneck(64, 64)
  59. x = torch.randn(1, 64, 56, 56)
  60. x_out = block(x)
  61. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  62. def test_res_layer(self):
  63. # Test ResLayer of 3 Bottleneck w\o downsample
  64. layer = ResLayer(SEBottleneck, 3, 64, 64, se_ratio=16)
  65. self.assertEqual(len(layer), 3)
  66. self.assertEqual(layer[0].conv1.in_channels, 64)
  67. self.assertEqual(layer[0].conv1.out_channels, 16)
  68. for i in range(1, len(layer)):
  69. self.assertEqual(layer[i].conv1.in_channels, 64)
  70. self.assertEqual(layer[i].conv1.out_channels, 16)
  71. for i in range(len(layer)):
  72. self.assertIsNone(layer[i].downsample)
  73. x = torch.randn(1, 64, 56, 56)
  74. x_out = layer(x)
  75. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  76. # Test ResLayer of 3 SEBottleneck with downsample
  77. layer = ResLayer(SEBottleneck, 3, 64, 256, se_ratio=16)
  78. self.assertEqual(layer[0].downsample[0].out_channels, 256)
  79. for i in range(1, len(layer)):
  80. self.assertIsNone(layer[i].downsample)
  81. x = torch.randn(1, 64, 56, 56)
  82. x_out = layer(x)
  83. self.assertEqual(x_out.shape, torch.Size([1, 256, 56, 56]))
  84. # Test ResLayer of 3 SEBottleneck with stride=2
  85. layer = ResLayer(SEBottleneck, 3, 64, 256, stride=2, se_ratio=8)
  86. self.assertEqual(layer[0].downsample[0].out_channels, 256)
  87. self.assertEqual(layer[0].downsample[0].stride, (2, 2))
  88. for i in range(1, len(layer)):
  89. self.assertIsNone(layer[i].downsample)
  90. x = torch.randn(1, 64, 56, 56)
  91. x_out = layer(x)
  92. self.assertEqual(x_out.shape, torch.Size([1, 256, 28, 28]))
  93. # Test ResLayer of 3 SEBottleneck with stride=2 and average downsample
  94. layer = ResLayer(
  95. SEBottleneck, 3, 64, 256, stride=2, avg_down=True, se_ratio=8)
  96. self.assertIsInstance(layer[0].downsample[0], AvgPool2d)
  97. self.assertEqual(layer[0].downsample[1].out_channels, 256)
  98. self.assertEqual(layer[0].downsample[1].stride, (1, 1))
  99. for i in range(1, len(layer)):
  100. self.assertIsNone(layer[i].downsample)
  101. x = torch.randn(1, 64, 56, 56)
  102. x_out = layer(x)
  103. self.assertEqual(x_out.shape, torch.Size([1, 256, 28, 28]))
  104. def test_seresnet(self):
  105. """Test resnet backbone."""
  106. with self.assertRaises(KeyError):
  107. # SEResNet depth should be in [50, 101, 152]
  108. SEResNet(20)
  109. with self.assertRaises(AssertionError):
  110. # In SEResNet: 1 <= num_stages <= 4
  111. SEResNet(50, num_stages=0)
  112. with self.assertRaises(AssertionError):
  113. # In SEResNet: 1 <= num_stages <= 4
  114. SEResNet(50, num_stages=5)
  115. with self.assertRaises(AssertionError):
  116. # len(strides) == len(dilations) == num_stages
  117. SEResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
  118. with self.assertRaises(AssertionError):
  119. # Style must be in ['pytorch', 'caffe']
  120. SEResNet(50, style='tensorflow')
  121. # Test SEResNet50 norm_eval=True
  122. model = SEResNet(50, norm_eval=True)
  123. model.init_weights()
  124. model.train()
  125. self.assertTrue(self.check_norm_state(model.modules(), False))
  126. # Test SEResNet50 with torchvision pretrained weight
  127. init_cfg = dict(type='Pretrained', checkpoint='torchvision://resnet50')
  128. model = SEResNet(depth=50, norm_eval=True, init_cfg=init_cfg)
  129. model.train()
  130. self.assertTrue(self.check_norm_state(model.modules(), False))
  131. # Test SEResNet50 with first stage frozen
  132. frozen_stages = 1
  133. model = SEResNet(50, frozen_stages=frozen_stages)
  134. model.init_weights()
  135. model.train()
  136. self.assertFalse(model.norm1.training)
  137. for layer in [model.conv1, model.norm1]:
  138. for param in layer.parameters():
  139. self.assertFalse(param.requires_grad)
  140. for i in range(1, frozen_stages + 1):
  141. layer = getattr(model, f'layer{i}')
  142. for mod in layer.modules():
  143. if isinstance(mod, _BatchNorm):
  144. self.assertFalse(mod.training)
  145. for param in layer.parameters():
  146. self.assertFalse(param.requires_grad)
  147. # Test SEResNet50 with BatchNorm forward
  148. model = SEResNet(50, out_indices=(0, 1, 2, 3))
  149. model.init_weights()
  150. model.train()
  151. imgs = torch.randn(1, 3, 224, 224)
  152. feat = model(imgs)
  153. self.assertEqual(len(feat), 4)
  154. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  155. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  156. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  157. self.assertEqual(feat[3].shape, torch.Size([1, 2048, 7, 7]))
  158. # Test SEResNet50 with layers 1, 2, 3 out forward
  159. model = SEResNet(50, out_indices=(0, 1, 2))
  160. model.init_weights()
  161. model.train()
  162. imgs = torch.randn(1, 3, 224, 224)
  163. feat = model(imgs)
  164. self.assertEqual(len(feat), 3)
  165. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  166. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  167. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  168. # Test SEResNet50 with layers 3 (top feature maps) out forward
  169. model = SEResNet(50, out_indices=(3, ))
  170. model.init_weights()
  171. model.train()
  172. imgs = torch.randn(1, 3, 224, 224)
  173. feat = model(imgs)
  174. self.assertIsInstance(feat, tuple)
  175. self.assertEqual(feat[-1].shape, torch.Size([1, 2048, 7, 7]))
  176. # Test SEResNet50 with checkpoint forward
  177. model = SEResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
  178. for m in model.modules():
  179. if isinstance(m, SEBottleneck):
  180. self.assertTrue(m.with_cp)
  181. model.init_weights()
  182. model.train()
  183. imgs = torch.randn(1, 3, 224, 224)
  184. feat = model(imgs)
  185. self.assertEqual(len(feat), 4)
  186. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  187. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  188. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  189. self.assertEqual(feat[3].shape, torch.Size([1, 2048, 7, 7]))
  190. # Test SEResNet50 zero initialization of residual
  191. model = SEResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
  192. model.init_weights()
  193. for m in model.modules():
  194. if isinstance(m, SEBottleneck):
  195. self.assertTrue(self.all_zeros(m.norm3))
  196. model.train()
  197. imgs = torch.randn(1, 3, 224, 224)
  198. feat = model(imgs)
  199. self.assertEqual(len(feat), 4)
  200. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  201. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  202. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  203. self.assertEqual(feat[3].shape, torch.Size([1, 2048, 7, 7]))