test_simcc_label.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import SimCCLabel # noqa: F401
  5. from mmpose.registry import KEYPOINT_CODECS
  6. class TestSimCCLabel(TestCase):
  7. # name and configs of all test cases
  8. def setUp(self) -> None:
  9. self.configs = [
  10. (
  11. 'simcc gaussian',
  12. dict(
  13. type='SimCCLabel',
  14. input_size=(192, 256),
  15. smoothing_type='gaussian',
  16. sigma=6.0,
  17. simcc_split_ratio=2.0),
  18. ),
  19. (
  20. 'simcc smoothing',
  21. dict(
  22. type='SimCCLabel',
  23. input_size=(192, 256),
  24. smoothing_type='standard',
  25. sigma=5.0,
  26. simcc_split_ratio=3.0,
  27. label_smooth_weight=0.1),
  28. ),
  29. (
  30. 'simcc one-hot',
  31. dict(
  32. type='SimCCLabel',
  33. input_size=(192, 256),
  34. smoothing_type='standard',
  35. sigma=5.0,
  36. simcc_split_ratio=3.0),
  37. ),
  38. (
  39. 'simcc dark',
  40. dict(
  41. type='SimCCLabel',
  42. input_size=(192, 256),
  43. smoothing_type='gaussian',
  44. sigma=6.0,
  45. simcc_split_ratio=2.0,
  46. use_dark=True),
  47. ),
  48. (
  49. 'simcc separated sigmas',
  50. dict(
  51. type='SimCCLabel',
  52. input_size=(192, 256),
  53. smoothing_type='gaussian',
  54. sigma=(4.9, 5.66),
  55. simcc_split_ratio=2.0),
  56. ),
  57. ]
  58. # The bbox is usually padded so the keypoint will not be near the
  59. # boundary
  60. keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
  61. keypoints = np.round(keypoints).astype(np.float32)
  62. keypoints_visible = np.ones((1, 17), dtype=np.float32)
  63. self.data = dict(
  64. keypoints=keypoints, keypoints_visible=keypoints_visible)
  65. def test_encode(self):
  66. keypoints = self.data['keypoints']
  67. keypoints_visible = self.data['keypoints_visible']
  68. for name, cfg in self.configs:
  69. codec = KEYPOINT_CODECS.build(cfg)
  70. encoded = codec.encode(keypoints, keypoints_visible)
  71. self.assertEqual(encoded['keypoint_x_labels'].shape,
  72. (1, 17, int(192 * codec.simcc_split_ratio)),
  73. f'Failed case: "{name}"')
  74. self.assertEqual(encoded['keypoint_y_labels'].shape,
  75. (1, 17, int(256 * codec.simcc_split_ratio)),
  76. f'Failed case: "{name}"')
  77. self.assertEqual(encoded['keypoint_weights'].shape, (1, 17),
  78. f'Failed case: "{name}"')
  79. def test_decode(self):
  80. for name, cfg in self.configs:
  81. codec = KEYPOINT_CODECS.build(cfg)
  82. simcc_x = np.random.rand(1, 17, int(192 * codec.simcc_split_ratio))
  83. simcc_y = np.random.rand(1, 17, int(256 * codec.simcc_split_ratio))
  84. keypoints, scores = codec.decode(simcc_x, simcc_y)
  85. self.assertEqual(keypoints.shape, (1, 17, 2),
  86. f'Failed case: "{name}"')
  87. self.assertEqual(scores.shape, (1, 17), f'Failed case: "{name}"')
  88. def test_cicular_verification(self):
  89. keypoints = self.data['keypoints']
  90. keypoints_visible = self.data['keypoints_visible']
  91. for name, cfg in self.configs:
  92. codec = KEYPOINT_CODECS.build(cfg)
  93. encoded = codec.encode(keypoints, keypoints_visible)
  94. keypoint_x_labels = encoded['keypoint_x_labels']
  95. keypoint_y_labels = encoded['keypoint_y_labels']
  96. _keypoints, _ = codec.decode(keypoint_x_labels, keypoint_y_labels)
  97. self.assertTrue(
  98. np.allclose(keypoints, _keypoints, atol=5.),
  99. f'Failed case: "{name}"')
  100. def test_errors(self):
  101. cfg = dict(
  102. type='SimCCLabel',
  103. input_size=(192, 256),
  104. smoothing_type='uniform',
  105. sigma=1.0,
  106. simcc_split_ratio=2.0)
  107. with self.assertRaisesRegex(ValueError,
  108. 'got invalid `smoothing_type`'):
  109. _ = KEYPOINT_CODECS.build(cfg)
  110. # invalid label_smooth_weight in smoothing
  111. cfg = dict(
  112. type='SimCCLabel',
  113. input_size=(192, 256),
  114. smoothing_type='standard',
  115. sigma=1.0,
  116. simcc_split_ratio=2.0,
  117. label_smooth_weight=1.1)
  118. with self.assertRaisesRegex(ValueError,
  119. '`label_smooth_weight` should be'):
  120. _ = KEYPOINT_CODECS.build(cfg)
  121. # invalid label_smooth_weight for gaussian
  122. cfg = dict(
  123. type='SimCCLabel',
  124. input_size=(192, 256),
  125. smoothing_type='gaussian',
  126. sigma=1.0,
  127. simcc_split_ratio=2.0,
  128. label_smooth_weight=0.1)
  129. with self.assertRaisesRegex(ValueError,
  130. 'is only used for `standard` mode.'):
  131. _ = KEYPOINT_CODECS.build(cfg)