test_regression_label.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import RegressionLabel # noqa: F401
  5. from mmpose.registry import KEYPOINT_CODECS
  6. class TestRegressionLabel(TestCase):
  7. # name and configs of all test cases
  8. def setUp(self) -> None:
  9. self.configs = [
  10. (
  11. 'regression',
  12. dict(
  13. type='RegressionLabel',
  14. input_size=(192, 256),
  15. ),
  16. ),
  17. ]
  18. # The bbox is usually padded so the keypoint will not be near the
  19. # boundary
  20. keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
  21. keypoints = np.round(keypoints).astype(np.float32)
  22. encoded_with_sigma = np.random.rand(1, 17, 4)
  23. encoded_wo_sigma = np.random.rand(1, 17, 2)
  24. keypoints_visible = np.ones((1, 17), dtype=np.float32)
  25. self.data = dict(
  26. keypoints=keypoints,
  27. keypoints_visible=keypoints_visible,
  28. encoded_with_sigma=encoded_with_sigma,
  29. encoded_wo_sigma=encoded_wo_sigma)
  30. def test_encode(self):
  31. keypoints = self.data['keypoints']
  32. keypoints_visible = self.data['keypoints_visible']
  33. for name, cfg in self.configs:
  34. codec = KEYPOINT_CODECS.build(cfg)
  35. encoded = codec.encode(keypoints, keypoints_visible)
  36. self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2),
  37. f'Failed case: "{name}"')
  38. self.assertEqual(encoded['keypoint_weights'].shape, (1, 17),
  39. f'Failed case: "{name}"')
  40. def test_decode(self):
  41. encoded_with_sigma = self.data['encoded_with_sigma']
  42. encoded_wo_sigma = self.data['encoded_wo_sigma']
  43. for name, cfg in self.configs:
  44. codec = KEYPOINT_CODECS.build(cfg)
  45. keypoints1, scores1 = codec.decode(encoded_with_sigma)
  46. keypoints2, scores2 = codec.decode(encoded_wo_sigma)
  47. self.assertEqual(keypoints1.shape, (1, 17, 2),
  48. f'Failed case: "{name}"')
  49. self.assertEqual(scores1.shape, (1, 17), f'Failed case: "{name}"')
  50. self.assertEqual(keypoints2.shape, (1, 17, 2),
  51. f'Failed case: "{name}"')
  52. self.assertEqual(scores2.shape, (1, 17), f'Failed case: "{name}"')
  53. def test_cicular_verification(self):
  54. keypoints = self.data['keypoints']
  55. keypoints_visible = self.data['keypoints_visible']
  56. for name, cfg in self.configs:
  57. codec = KEYPOINT_CODECS.build(cfg)
  58. encoded = codec.encode(keypoints, keypoints_visible)
  59. _keypoints, _ = codec.decode(encoded['keypoint_labels'])
  60. self.assertTrue(
  61. np.allclose(keypoints, _keypoints, atol=5.),
  62. f'Failed case: "{name}"')