test_integral_regression_label.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import IntegralRegressionLabel # noqa: F401
  5. from mmpose.registry import KEYPOINT_CODECS
  6. class TestRegressionLabel(TestCase):
  7. # name and configs of all test cases
  8. def setUp(self) -> None:
  9. self.configs = [
  10. (
  11. 'ipr',
  12. dict(
  13. type='IntegralRegressionLabel',
  14. input_size=(192, 256),
  15. heatmap_size=(48, 64),
  16. sigma=2),
  17. ),
  18. ]
  19. # The bbox is usually padded so the keypoint will not be near the
  20. # boundary
  21. keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
  22. keypoints = np.round(keypoints).astype(np.float32)
  23. heatmaps = np.random.rand(17, 64, 48).astype(np.float32)
  24. encoded_wo_sigma = np.random.rand(1, 17, 2)
  25. keypoints_visible = np.ones((1, 17), dtype=np.float32)
  26. self.data = dict(
  27. keypoints=keypoints,
  28. keypoints_visible=keypoints_visible,
  29. heatmaps=heatmaps,
  30. encoded_wo_sigma=encoded_wo_sigma)
  31. def test_encode(self):
  32. keypoints = self.data['keypoints']
  33. keypoints_visible = self.data['keypoints_visible']
  34. for name, cfg in self.configs:
  35. codec = KEYPOINT_CODECS.build(cfg)
  36. encoded = codec.encode(keypoints, keypoints_visible)
  37. heatmaps = encoded['heatmaps']
  38. keypoint_labels = encoded['keypoint_labels']
  39. keypoint_weights = encoded['keypoint_weights']
  40. self.assertEqual(heatmaps.shape, (17, 64, 48),
  41. f'Failed case: "{name}"')
  42. self.assertEqual(keypoint_labels.shape, (1, 17, 2),
  43. f'Failed case: "{name}"')
  44. self.assertEqual(keypoint_weights.shape, (1, 17),
  45. f'Failed case: "{name}"')
  46. def test_decode(self):
  47. encoded_wo_sigma = self.data['encoded_wo_sigma']
  48. for name, cfg in self.configs:
  49. codec = KEYPOINT_CODECS.build(cfg)
  50. keypoints, scores = codec.decode(encoded_wo_sigma)
  51. self.assertEqual(keypoints.shape, (1, 17, 2),
  52. f'Failed case: "{name}"')
  53. self.assertEqual(scores.shape, (1, 17), f'Failed case: "{name}"')
  54. def test_cicular_verification(self):
  55. keypoints = self.data['keypoints']
  56. keypoints_visible = self.data['keypoints_visible']
  57. for name, cfg in self.configs:
  58. codec = KEYPOINT_CODECS.build(cfg)
  59. encoded = codec.encode(keypoints, keypoints_visible)
  60. keypoint_labels = encoded['keypoint_labels']
  61. _keypoints, _ = codec.decode(keypoint_labels)
  62. self.assertTrue(
  63. np.allclose(keypoints, _keypoints, atol=5.),
  64. f'Failed case: "{name}"')