test_cpm.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import CPM
  5. from mmpose.models.backbones.cpm import CpmBlock
  6. class TestCPM(TestCase):
  7. def test_cpm_block(self):
  8. with self.assertRaises(AssertionError):
  9. # len(channels) == len(kernels)
  10. CpmBlock(
  11. 3, channels=[3, 3, 3], kernels=[
  12. 1,
  13. ])
  14. # Test CPM Block
  15. model = CpmBlock(3, channels=[3, 3, 3], kernels=[1, 1, 1])
  16. model.train()
  17. imgs = torch.randn(1, 3, 10, 10)
  18. feat = model(imgs)
  19. self.assertEqual(feat.shape, torch.Size([1, 3, 10, 10]))
  20. def test_cpm_backbone(self):
  21. with self.assertRaises(AssertionError):
  22. # CPM's num_stacks should larger than 0
  23. CPM(in_channels=3, out_channels=17, num_stages=-1)
  24. with self.assertRaises(AssertionError):
  25. # CPM's in_channels should be 3
  26. CPM(in_channels=2, out_channels=17)
  27. # Test CPM
  28. model = CPM(in_channels=3, out_channels=17, num_stages=1)
  29. model.init_weights()
  30. model.train()
  31. imgs = torch.randn(1, 3, 256, 192)
  32. feat = model(imgs)
  33. self.assertEqual(len(feat), 1)
  34. self.assertEqual(feat[0].shape, torch.Size([1, 17, 32, 24]))
  35. imgs = torch.randn(1, 3, 384, 288)
  36. feat = model(imgs)
  37. self.assertEqual(len(feat), 1)
  38. self.assertEqual(feat[0].shape, torch.Size([1, 17, 48, 36]))
  39. imgs = torch.randn(1, 3, 368, 368)
  40. feat = model(imgs)
  41. self.assertEqual(len(feat), 1)
  42. self.assertEqual(feat[0].shape, torch.Size([1, 17, 46, 46]))
  43. # Test CPM multi-stages
  44. model = CPM(in_channels=3, out_channels=17, num_stages=2)
  45. model.init_weights()
  46. model.train()
  47. imgs = torch.randn(1, 3, 368, 368)
  48. feat = model(imgs)
  49. self.assertEqual(len(feat), 2)
  50. self.assertEqual(feat[0].shape, torch.Size([1, 17, 46, 46]))
  51. self.assertEqual(feat[1].shape, torch.Size([1, 17, 46, 46]))