# test_mspn.py (1.2 KB)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpose.models.backbones import MSPN


  5. class TestMSPN(TestCase):
  6. def test_mspn_backbone(self):
  7. with self.assertRaises(AssertionError):
  8. # MSPN's num_stages should larger than 0
  9. MSPN(num_stages=0)
  10. with self.assertRaises(AssertionError):
  11. # MSPN's num_units should larger than 1
  12. MSPN(num_units=1)
  13. with self.assertRaises(AssertionError):
  14. # len(num_blocks) should equal num_units
  15. MSPN(num_units=2, num_blocks=[2, 2, 2])
  16. # Test MSPN's outputs
  17. model = MSPN(num_stages=2, num_units=2, num_blocks=[2, 2])
  18. model.init_weights()
  19. model.train()
  20. imgs = torch.randn(1, 3, 511, 511)
  21. feat = model(imgs)
  22. self.assertEqual(len(feat), 2)
  23. self.assertEqual(len(feat[0]), 2)
  24. self.assertEqual(len(feat[1]), 2)
  25. self.assertEqual(feat[0][0].shape, torch.Size([1, 256, 64, 64]))
  26. self.assertEqual(feat[0][1].shape, torch.Size([1, 256, 128, 128]))
  27. self.assertEqual(feat[1][0].shape, torch.Size([1, 256, 64, 64]))
  28. self.assertEqual(feat[1][1].shape, torch.Size([1, 256, 128, 128]))