test_msra_heatmap.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import MSRAHeatmap
  5. from mmpose.registry import KEYPOINT_CODECS
  6. class TestMSRAHeatmap(TestCase):
  7. def setUp(self) -> None:
  8. # name and configs of all test cases
  9. self.configs = [
  10. (
  11. 'msra',
  12. dict(
  13. type='MSRAHeatmap',
  14. input_size=(192, 256),
  15. heatmap_size=(48, 64),
  16. sigma=2.0),
  17. ),
  18. (
  19. 'msra+dark',
  20. dict(
  21. type='MSRAHeatmap',
  22. input_size=(192, 256),
  23. heatmap_size=(48, 64),
  24. sigma=2.0,
  25. unbiased=True),
  26. ),
  27. ]
  28. # The bbox is usually padded so the keypoint will not be near the
  29. # boundary
  30. keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
  31. keypoints = np.round(keypoints).astype(np.float32)
  32. keypoints_visible = np.ones((1, 17), dtype=np.float32)
  33. heatmaps = np.random.rand(17, 64, 48).astype(np.float32)
  34. self.data = dict(
  35. keypoints=keypoints,
  36. keypoints_visible=keypoints_visible,
  37. heatmaps=heatmaps)
  38. def test_encode(self):
  39. keypoints = self.data['keypoints']
  40. keypoints_visible = self.data['keypoints_visible']
  41. for name, cfg in self.configs:
  42. codec = KEYPOINT_CODECS.build(cfg)
  43. encoded = codec.encode(keypoints, keypoints_visible)
  44. self.assertEqual(encoded['heatmaps'].shape, (17, 64, 48),
  45. f'Failed case: "{name}"')
  46. self.assertEqual(encoded['keypoint_weights'].shape,
  47. (1, 17)), f'Failed case: "{name}"'
  48. def test_decode(self):
  49. heatmaps = self.data['heatmaps']
  50. for name, cfg in self.configs:
  51. codec = KEYPOINT_CODECS.build(cfg)
  52. keypoints, scores = codec.decode(heatmaps)
  53. self.assertEqual(keypoints.shape, (1, 17, 2),
  54. f'Failed case: "{name}"')
  55. self.assertEqual(scores.shape, (1, 17), f'Failed case: "{name}"')
  56. def test_cicular_verification(self):
  57. keypoints = self.data['keypoints']
  58. keypoints_visible = self.data['keypoints_visible']
  59. for name, cfg in self.configs:
  60. codec = KEYPOINT_CODECS.build(cfg)
  61. encoded = codec.encode(keypoints, keypoints_visible)
  62. _keypoints, _ = codec.decode(encoded['heatmaps'])
  63. self.assertTrue(
  64. np.allclose(keypoints, _keypoints, atol=5.),
  65. f'Failed case: "{name}"')
  66. def test_errors(self):
  67. # multiple instance
  68. codec = MSRAHeatmap(
  69. input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)
  70. keypoints = np.random.rand(2, 17, 2)
  71. keypoints_visible = np.random.rand(2, 17)
  72. with self.assertRaisesRegex(AssertionError,
  73. 'only support single-instance'):
  74. codec.encode(keypoints, keypoints_visible)