# test_resnest.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpose.models.backbones import ResNeSt
from mmpose.models.backbones.resnest import Bottleneck as BottleneckS
  6. class TestResnest(TestCase):
  7. def test_bottleneck(self):
  8. with self.assertRaises(AssertionError):
  9. # Style must be in ['pytorch', 'caffe']
  10. BottleneckS(
  11. 64, 64, radix=2, reduction_factor=4, style='tensorflow')
  12. # Test ResNeSt Bottleneck structure
  13. block = BottleneckS(
  14. 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
  15. self.assertEqual(block.avd_layer.stride, 2)
  16. self.assertEqual(block.conv2.channels, 64)
  17. # Test ResNeSt Bottleneck forward
  18. block = BottleneckS(64, 64, radix=2, reduction_factor=4)
  19. x = torch.randn(2, 64, 56, 56)
  20. x_out = block(x)
  21. self.assertEqual(x_out.shape, torch.Size([2, 64, 56, 56]))
  22. def test_resnest(self):
  23. with self.assertRaises(KeyError):
  24. # ResNeSt depth should be in [50, 101, 152, 200]
  25. ResNeSt(depth=18)
  26. # Test ResNeSt with radix 2, reduction_factor 4
  27. model = ResNeSt(
  28. depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
  29. model.init_weights()
  30. model.train()
  31. imgs = torch.randn(2, 3, 224, 224)
  32. feat = model(imgs)
  33. self.assertEqual(len(feat), 4)
  34. self.assertEqual(feat[0].shape, torch.Size([2, 256, 56, 56]))
  35. self.assertEqual(feat[1].shape, torch.Size([2, 512, 28, 28]))
  36. self.assertEqual(feat[2].shape, torch.Size([2, 1024, 14, 14]))
  37. self.assertEqual(feat[3].shape, torch.Size([2, 2048, 7, 7]))