test_spr.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import SPR
  5. from mmpose.registry import KEYPOINT_CODECS
  6. from mmpose.testing import get_coco_sample
  7. from mmpose.utils.tensor_utils import to_numpy, to_tensor
  8. class TestSPR(TestCase):
  9. def setUp(self) -> None:
  10. pass
  11. def _make_multi_instance_data(self, data):
  12. keypoints = data['keypoints']
  13. keypoints_visible = data['keypoints_visible']
  14. keypoints_visible[..., 0] = 0
  15. keypoints_outside = keypoints - keypoints.max(axis=-1, keepdims=True)
  16. keypoints_outside_visible = np.zeros(keypoints_visible.shape)
  17. keypoint_overlap = keypoints.mean(
  18. axis=-1, keepdims=True) + 0.8 * (
  19. keypoints - keypoints.mean(axis=-1, keepdims=True))
  20. keypoint_overlap_visible = keypoints_visible
  21. data['keypoints'] = np.concatenate(
  22. (keypoints, keypoints_outside, keypoint_overlap), axis=0)
  23. data['keypoints_visible'] = np.concatenate(
  24. (keypoints_visible, keypoints_outside_visible,
  25. keypoint_overlap_visible),
  26. axis=0)
  27. return data
  28. def test_build(self):
  29. cfg = dict(
  30. type='SPR',
  31. input_size=(512, 512),
  32. heatmap_size=(128, 128),
  33. sigma=4,
  34. )
  35. codec = KEYPOINT_CODECS.build(cfg)
  36. self.assertIsInstance(codec, SPR)
  37. def test_encode(self):
  38. data = get_coco_sample(img_shape=(512, 512), num_instances=1)
  39. data = self._make_multi_instance_data(data)
  40. # w/o keypoint heatmaps
  41. codec = SPR(
  42. input_size=(512, 512),
  43. heatmap_size=(128, 128),
  44. sigma=4,
  45. )
  46. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  47. heatmaps = encoded['heatmaps']
  48. displacements = encoded['displacements']
  49. heatmap_weights = encoded['heatmap_weights']
  50. displacement_weights = encoded['displacement_weights']
  51. self.assertEqual(heatmaps.shape, (1, 128, 128))
  52. self.assertEqual(heatmap_weights.shape, (1, 128, 128))
  53. self.assertEqual(displacements.shape, (34, 128, 128))
  54. self.assertEqual(displacement_weights.shape, (34, 128, 128))
  55. # w/ keypoint heatmaps
  56. with self.assertRaises(AssertionError):
  57. codec = SPR(
  58. input_size=(512, 512),
  59. heatmap_size=(128, 128),
  60. sigma=4,
  61. generate_keypoint_heatmaps=True,
  62. )
  63. codec = SPR(
  64. input_size=(512, 512),
  65. heatmap_size=(128, 128),
  66. sigma=(4, 2),
  67. generate_keypoint_heatmaps=True,
  68. )
  69. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  70. heatmaps = encoded['heatmaps']
  71. displacements = encoded['displacements']
  72. heatmap_weights = encoded['heatmap_weights']
  73. displacement_weights = encoded['displacement_weights']
  74. self.assertEqual(heatmaps.shape, (18, 128, 128))
  75. self.assertEqual(heatmap_weights.shape, (18, 128, 128))
  76. self.assertEqual(displacements.shape, (34, 128, 128))
  77. self.assertEqual(displacement_weights.shape, (34, 128, 128))
  78. # root_type
  79. with self.assertRaises(ValueError):
  80. codec = SPR(
  81. input_size=(512, 512),
  82. heatmap_size=(128, 128),
  83. sigma=(4, ),
  84. root_type='box_center',
  85. )
  86. encoded = codec.encode(data['keypoints'],
  87. data['keypoints_visible'])
  88. codec = SPR(
  89. input_size=(512, 512),
  90. heatmap_size=(128, 128),
  91. sigma=(4, ),
  92. root_type='bbox_center',
  93. )
  94. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  95. heatmaps = encoded['heatmaps']
  96. displacements = encoded['displacements']
  97. heatmap_weights = encoded['heatmap_weights']
  98. displacement_weights = encoded['displacement_weights']
  99. self.assertEqual(heatmaps.shape, (1, 128, 128))
  100. self.assertEqual(heatmap_weights.shape, (1, 128, 128))
  101. self.assertEqual(displacements.shape, (34, 128, 128))
  102. self.assertEqual(displacement_weights.shape, (34, 128, 128))
  103. def test_decode(self):
  104. data = get_coco_sample(img_shape=(512, 512), num_instances=1)
  105. # decode w/o keypoint heatmaps
  106. codec = SPR(
  107. input_size=(512, 512),
  108. heatmap_size=(128, 128),
  109. sigma=(4, ),
  110. generate_keypoint_heatmaps=False,
  111. )
  112. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  113. decoded = codec.decode(
  114. to_tensor(encoded['heatmaps']),
  115. to_tensor(encoded['displacements']))
  116. keypoints, (root_scores, keypoint_scores) = decoded
  117. self.assertIsNone(keypoint_scores)
  118. self.assertEqual(keypoints.shape, data['keypoints'].shape)
  119. self.assertEqual(root_scores.shape, data['keypoints'].shape[:1])
  120. # decode w/ keypoint heatmaps
  121. codec = SPR(
  122. input_size=(512, 512),
  123. heatmap_size=(128, 128),
  124. sigma=(4, 2),
  125. generate_keypoint_heatmaps=True,
  126. )
  127. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  128. decoded = codec.decode(
  129. to_tensor(encoded['heatmaps']),
  130. to_tensor(encoded['displacements']))
  131. keypoints, (root_scores, keypoint_scores) = decoded
  132. self.assertIsNotNone(keypoint_scores)
  133. self.assertEqual(keypoints.shape, data['keypoints'].shape)
  134. self.assertEqual(root_scores.shape, data['keypoints'].shape[:1])
  135. self.assertEqual(keypoint_scores.shape, data['keypoints'].shape[:2])
  136. def test_cicular_verification(self):
  137. data = get_coco_sample(img_shape=(512, 512), num_instances=1)
  138. codec = SPR(
  139. input_size=(512, 512),
  140. heatmap_size=(128, 128),
  141. sigma=(4, ),
  142. generate_keypoint_heatmaps=False,
  143. )
  144. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  145. decoded = codec.decode(
  146. to_tensor(encoded['heatmaps']),
  147. to_tensor(encoded['displacements']))
  148. keypoints, _ = decoded
  149. self.assertTrue(
  150. np.allclose(to_numpy(keypoints), data['keypoints'], atol=5.))