# test_resnext.py (2.4 KB)
# Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import ResNeXt
  5. from mmpose.models.backbones.resnext import Bottleneck as BottleneckX
  6. class TestResnext(TestCase):
  7. def test_bottleneck(self):
  8. with self.assertRaises(AssertionError):
  9. # Style must be in ['pytorch', 'caffe']
  10. BottleneckX(
  11. 64, 64, groups=32, width_per_group=4, style='tensorflow')
  12. # Test ResNeXt Bottleneck structure
  13. block = BottleneckX(
  14. 64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
  15. self.assertEqual(block.conv2.stride, (2, 2))
  16. self.assertEqual(block.conv2.groups, 32)
  17. self.assertEqual(block.conv2.out_channels, 128)
  18. # Test ResNeXt Bottleneck forward
  19. block = BottleneckX(
  20. 64, 64, base_channels=16, groups=32, width_per_group=4)
  21. x = torch.randn(1, 64, 56, 56)
  22. x_out = block(x)
  23. self.assertEqual(x_out.shape, torch.Size([1, 64, 56, 56]))
  24. def test_resnext(self):
  25. with self.assertRaises(KeyError):
  26. # ResNeXt depth should be in [50, 101, 152]
  27. ResNeXt(depth=18)
  28. # Test ResNeXt with group 32, width_per_group 4
  29. model = ResNeXt(
  30. depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3))
  31. for m in model.modules():
  32. if isinstance(m, BottleneckX):
  33. self.assertEqual(m.conv2.groups, 32)
  34. model.init_weights()
  35. model.train()
  36. imgs = torch.randn(1, 3, 224, 224)
  37. feat = model(imgs)
  38. self.assertEqual(len(feat), 4)
  39. self.assertEqual(feat[0].shape, torch.Size([1, 256, 56, 56]))
  40. self.assertEqual(feat[1].shape, torch.Size([1, 512, 28, 28]))
  41. self.assertEqual(feat[2].shape, torch.Size([1, 1024, 14, 14]))
  42. self.assertEqual(feat[3].shape, torch.Size([1, 2048, 7, 7]))
  43. # Test ResNeXt with layers 3 out forward
  44. model = ResNeXt(
  45. depth=50, groups=32, width_per_group=4, out_indices=(3, ))
  46. for m in model.modules():
  47. if isinstance(m, BottleneckX):
  48. self.assertEqual(m.conv2.groups, 32)
  49. model.init_weights()
  50. model.train()
  51. imgs = torch.randn(1, 3, 224, 224)
  52. feat = model(imgs)
  53. self.assertEqual(len(feat), 1)
  54. self.assertEqual(feat[-1].shape, torch.Size([1, 2048, 7, 7]))