test_rsn.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import RSN
  class TestRSN(TestCase):
      """Unit tests for the ``RSN`` backbone."""

      def test_rsn_backbone(self):
          """Check RSN's constructor validation and its output structure."""
          # Each config is expected to trip one of RSN's constructor
          # assertions. Running them under subTest labels every case, so a
          # config that unexpectedly constructs is reported by name, and the
          # remaining cases are still checked.
          invalid_cfgs = [
              # RSN's num_stages should be larger than 0
              dict(num_stages=0),
              # RSN's num_steps should be larger than 1
              dict(num_steps=1),
              # RSN's num_units should be larger than 1
              dict(num_units=1),
              # len(num_blocks) should equal num_units
              dict(num_units=2, num_blocks=[2, 2, 2]),
          ]
          for cfg in invalid_cfgs:
              with self.subTest(cfg=cfg):
                  with self.assertRaises(AssertionError):
                      RSN(**cfg)

          # Test RSN's outputs
          model = RSN(num_stages=2, num_units=2, num_blocks=[2, 2])
          model.init_weights()
          model.train()

          # NOTE(review): the odd spatial size (511) presumably exercises
          # rounding in the downsampling path; confirm against RSN.
          imgs = torch.randn(1, 3, 511, 511)
          feat = model(imgs)

          # The outer length matches num_stages=2 and each inner length matches
          # num_units=2 (presumably one list per stage, one map per unit).
          self.assertEqual(len(feat), 2)
          self.assertEqual(len(feat[0]), 2)
          self.assertEqual(len(feat[1]), 2)
          self.assertEqual(feat[0][0].shape, torch.Size([1, 256, 64, 64]))
          self.assertEqual(feat[0][1].shape, torch.Size([1, 256, 128, 128]))
          self.assertEqual(feat[1][0].shape, torch.Size([1, 256, 64, 64]))
          self.assertEqual(feat[1][1].shape, torch.Size([1, 256, 128, 128]))