# test_seresnext.py (3.0 KB)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpose.models.backbones import SEResNeXt
from mmpose.models.backbones.seresnext import SEBottleneck as SEBottleneckX
  6. class TestSEResnext(TestCase):
  7. def test_bottleneck(self):
  8. with self.assertRaises(AssertionError):
  9. # Style must be in ['pytorch', 'caffe']
  10. SEBottleneckX(
  11. 64, 64, groups=32, width_per_group=4, style='tensorflow')
  12. # Test SEResNeXt Bottleneck structure
  13. block = SEBottleneckX(
  14. 64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
  15. self.assertEqual(block.width_per_group, 4)
  16. self.assertEqual(block.conv2.stride, (2, 2))
  17. self.assertEqual(block.conv2.groups, 32)
  18. self.assertEqual(block.conv2.out_channels, 128)
  19. self.assertEqual(block.conv2.out_channels, block.mid_channels)
  20. # Test SEResNeXt Bottleneck structure (groups=1)
  21. block = SEBottleneckX(
  22. 64, 256, groups=1, width_per_group=4, stride=2, style='pytorch')
  23. self.assertEqual(block.conv2.stride, (2, 2))
  24. self.assertEqual(block.conv2.groups, 1)
  25. self.assertEqual(block.conv2.out_channels, 64)
  26. self.assertEqual(block.mid_channels, 64)
  27. self.assertEqual(block.conv2.out_channels, block.mid_channels)
  28. # Test SEResNeXt Bottleneck forward
  29. block = SEBottleneckX(
  30. 64, 64, base_channels=16, groups=32, width_per_group=4)
  31. x = torch.randn(1, 64, 56, 56)
  32. x_out = block(x)
  33. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  34. def test_seresnext(self):
  35. with self.assertRaises(KeyError):
  36. # SEResNeXt depth should be in [50, 101, 152]
  37. SEResNeXt(depth=18)
  38. # Test SEResNeXt with group 32, width_per_group 4
  39. model = SEResNeXt(
  40. depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
  41. for m in model.modules():
  42. if isinstance(m, SEBottleneckX):
  43. self.assertEqual(m.conv2.groups, 32)
  44. model.init_weights()
  45. model.train()
  46. imgs = torch.randn(1, 3, 224, 224)
  47. feat = model(imgs)
  48. self.assertEqual(len(feat), 4)
  49. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  50. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  51. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  52. self.assertEqual(feat[3].shape, torch.Size([1, 2048, 7, 7]))
  53. # Test SEResNeXt with layers 3 out forward
  54. model = SEResNeXt(
  55. depth=50, groups=32, width_per_group=4, out_indices=(3, ))
  56. for m in model.modules():
  57. if isinstance(m, SEBottleneckX):
  58. self.assertEqual(m.conv2.groups, 32)
  59. model.init_weights()
  60. model.train()
  61. imgs = torch.randn(1, 3, 224, 224)
  62. feat = model(imgs)
  63. self.assertIsInstance(feat, tuple)
  64. self.assertEqual(feat[-1].shape, torch.Size([1, 2048, 7, 7]))