test_megvii_heatmap.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import MegviiHeatmap
  5. from mmpose.registry import KEYPOINT_CODECS
  class TestMegviiHeatmap(TestCase):
      """Unit tests for the ``MegviiHeatmap`` keypoint codec.

      Covers the output shapes of ``encode`` and ``decode``, an
      encode -> decode round trip, and the error raised by ``encode`` for
      multi-instance input.
      """

      def setUp(self) -> None:
          """Build the codec configs and the shared test inputs."""
          # name and configs of all test cases
          self.configs = [
              (
                  'megvii',
                  dict(
                      type='MegviiHeatmap',
                      input_size=(192, 256),
                      heatmap_size=(48, 64),
                      kernel_size=11),
              ),
          ]

          # Fixed seed so the random inputs, and therefore the
          # tolerance-based round-trip check, are reproducible across runs.
          rng = np.random.default_rng(0)

          # The bbox is usually padded so the keypoint will not be near the
          # boundary: sample keypoints in [10%, 90%) of the range, with x
          # scaled by 192 and y by 256, then round to whole coordinates.
          keypoints = (0.1 + 0.8 * rng.random((1, 17, 2))) * [192, 256]
          keypoints = np.round(keypoints).astype(np.float32)
          keypoints_visible = np.ones((1, 17), dtype=np.float32)
          # Random heatmaps; the tests expect the (17, 64, 48) layout that
          # ``encode`` produces for heatmap_size=(48, 64).
          heatmaps = rng.random((17, 64, 48)).astype(np.float32)
          self.data = dict(
              keypoints=keypoints,
              keypoints_visible=keypoints_visible,
              heatmaps=heatmaps)

      def test_encode(self):
          """Check the shapes of the heatmaps and weights from ``encode``."""
          keypoints = self.data['keypoints']
          keypoints_visible = self.data['keypoints_visible']

          for name, cfg in self.configs:
              codec = KEYPOINT_CODECS.build(cfg)
              encoded = codec.encode(keypoints, keypoints_visible)

              self.assertEqual(encoded['heatmaps'].shape, (17, 64, 48),
                               f'Failed case: "{name}"')
              # The message is passed as ``msg`` so a failure names the case.
              self.assertEqual(encoded['keypoint_weights'].shape, (1, 17),
                               f'Failed case: "{name}"')

      def test_decode(self):
          """Check the shapes of the keypoints and scores from ``decode``."""
          heatmaps = self.data['heatmaps']

          for name, cfg in self.configs:
              codec = KEYPOINT_CODECS.build(cfg)
              keypoints, scores = codec.decode(heatmaps)

              self.assertEqual(keypoints.shape, (1, 17, 2),
                               f'Failed case: "{name}"')
              self.assertEqual(scores.shape, (1, 17), f'Failed case: "{name}"')

      # NOTE(review): "cicular" is a typo for "circular"; the name is kept
      # unchanged so selecting this test by name keeps working.
      def test_cicular_verification(self):
          """Decoding the encoded heatmaps should recover the input keypoints.

          The recovered keypoints must match the originals within an absolute
          tolerance of 5, in the same coordinates as the input keypoints.
          """
          keypoints = self.data['keypoints']
          keypoints_visible = self.data['keypoints_visible']

          for name, cfg in self.configs:
              codec = KEYPOINT_CODECS.build(cfg)
              encoded = codec.encode(keypoints, keypoints_visible)
              _keypoints, _ = codec.decode(encoded['heatmaps'])

              self.assertTrue(
                  np.allclose(keypoints, _keypoints, atol=5.),
                  f'Failed case: "{name}"')

      def test_errors(self):
          """``encode`` must raise an AssertionError for multi-instance input."""
          # multiple instance
          codec = MegviiHeatmap(
              input_size=(192, 256), heatmap_size=(48, 64), kernel_size=11)
          # Two instances; the values are irrelevant, only the leading
          # instance dimension (2) should trigger the error.
          keypoints = np.random.rand(2, 17, 2)
          keypoints_visible = np.random.rand(2, 17)

          with self.assertRaisesRegex(AssertionError,
                                      'only support single-instance'):
              codec.encode(keypoints, keypoints_visible)