test_associative_embedding.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from itertools import product
  3. from unittest import TestCase
  4. import numpy as np
  5. import torch
  6. from munkres import Munkres
  7. from mmpose.codecs import AssociativeEmbedding
  8. from mmpose.registry import KEYPOINT_CODECS
  9. from mmpose.testing import get_coco_sample
  10. class TestAssociativeEmbedding(TestCase):
  11. def setUp(self) -> None:
  12. self.decode_keypoint_order = [
  13. 0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
  14. ]
  15. def test_build(self):
  16. cfg = dict(
  17. type='AssociativeEmbedding',
  18. input_size=(256, 256),
  19. heatmap_size=(64, 64),
  20. use_udp=False,
  21. decode_keypoint_order=self.decode_keypoint_order,
  22. )
  23. codec = KEYPOINT_CODECS.build(cfg)
  24. self.assertIsInstance(codec, AssociativeEmbedding)
  25. def test_encode(self):
  26. data = get_coco_sample(img_shape=(256, 256), num_instances=1)
  27. # w/o UDP
  28. codec = AssociativeEmbedding(
  29. input_size=(256, 256),
  30. heatmap_size=(64, 64),
  31. use_udp=False,
  32. decode_keypoint_order=self.decode_keypoint_order)
  33. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  34. heatmaps = encoded['heatmaps']
  35. keypoint_indices = encoded['keypoint_indices']
  36. keypoint_weights = encoded['keypoint_weights']
  37. self.assertEqual(heatmaps.shape, (17, 64, 64))
  38. self.assertEqual(keypoint_indices.shape, (1, 17, 2))
  39. self.assertEqual(keypoint_weights.shape, (1, 17))
  40. for k in range(heatmaps.shape[0]):
  41. index_expected = np.argmax(heatmaps[k])
  42. index_encoded = keypoint_indices[0, k, 0]
  43. self.assertEqual(index_expected, index_encoded)
  44. # w/ UDP
  45. codec = AssociativeEmbedding(
  46. input_size=(256, 256),
  47. heatmap_size=(64, 64),
  48. use_udp=True,
  49. decode_keypoint_order=self.decode_keypoint_order)
  50. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  51. heatmaps = encoded['heatmaps']
  52. keypoint_indices = encoded['keypoint_indices']
  53. keypoint_weights = encoded['keypoint_weights']
  54. self.assertEqual(heatmaps.shape, (17, 64, 64))
  55. self.assertEqual(keypoint_indices.shape, (1, 17, 2))
  56. self.assertEqual(keypoint_weights.shape, (1, 17))
  57. for k in range(heatmaps.shape[0]):
  58. index_expected = np.argmax(heatmaps[k])
  59. index_encoded = keypoint_indices[0, k, 0]
  60. self.assertEqual(index_expected, index_encoded)
  61. def _get_tags(self,
  62. heatmaps,
  63. keypoint_indices,
  64. tag_per_keypoint: bool,
  65. tag_dim: int = 1):
  66. K, H, W = heatmaps.shape
  67. N = keypoint_indices.shape[0]
  68. if tag_per_keypoint:
  69. tags = np.zeros((K * tag_dim, H, W), dtype=np.float32)
  70. else:
  71. tags = np.zeros((tag_dim, H, W), dtype=np.float32)
  72. for n, k in product(range(N), range(K)):
  73. y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W))
  74. if tag_per_keypoint:
  75. tags[k::K, y, x] = n
  76. else:
  77. tags[:, y, x] = n
  78. return tags
  79. def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt):
  80. """Sort multi-instance predictions to best match the ground-truth.
  81. Args:
  82. keypoints_pred (np.ndarray): predictions in shape (N, K, D)
  83. scores (np.ndarray): predictions in shape (N, K)
  84. keypoints_gt (np.ndarray): ground-truth in shape (N, K, D)
  85. Returns:
  86. np.ndarray: Sorted predictions
  87. """
  88. assert keypoints_gt.shape == keypoints_pred.shape
  89. costs = np.linalg.norm(
  90. keypoints_gt[None] - keypoints_pred[:, None], ord=2,
  91. axis=3).mean(axis=2)
  92. match = Munkres().compute(costs)
  93. keypoints_pred_sorted = np.zeros_like(keypoints_pred)
  94. scores_pred_sorted = np.zeros_like(scores_pred)
  95. for i, j in match:
  96. keypoints_pred_sorted[i] = keypoints_pred[j]
  97. scores_pred_sorted[i] = scores_pred[j]
  98. return keypoints_pred_sorted, scores_pred_sorted
  99. def test_decode(self):
  100. data = get_coco_sample(
  101. img_shape=(256, 256), num_instances=2, non_occlusion=True)
  102. # w/o UDP
  103. codec = AssociativeEmbedding(
  104. input_size=(256, 256),
  105. heatmap_size=(64, 64),
  106. use_udp=False,
  107. decode_keypoint_order=self.decode_keypoint_order)
  108. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  109. heatmaps = encoded['heatmaps']
  110. keypoint_indices = encoded['keypoint_indices']
  111. tags = self._get_tags(
  112. heatmaps, keypoint_indices, tag_per_keypoint=True)
  113. # to Tensor
  114. batch_heatmaps = torch.from_numpy(heatmaps[None])
  115. batch_tags = torch.from_numpy(tags[None])
  116. batch_keypoints, batch_keypoint_scores = codec.batch_decode(
  117. batch_heatmaps, batch_tags)
  118. self.assertIsInstance(batch_keypoints, list)
  119. self.assertIsInstance(batch_keypoint_scores, list)
  120. self.assertEqual(len(batch_keypoints), 1)
  121. self.assertEqual(len(batch_keypoint_scores), 1)
  122. keypoints, scores = self._sort_preds(batch_keypoints[0],
  123. batch_keypoint_scores[0],
  124. data['keypoints'])
  125. self.assertIsInstance(keypoints, np.ndarray)
  126. self.assertIsInstance(scores, np.ndarray)
  127. self.assertEqual(keypoints.shape, (2, 17, 2))
  128. self.assertEqual(scores.shape, (2, 17))
  129. self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
  130. # w/o UDP, tag_imd=2
  131. codec = AssociativeEmbedding(
  132. input_size=(256, 256),
  133. heatmap_size=(64, 64),
  134. use_udp=False,
  135. decode_keypoint_order=self.decode_keypoint_order)
  136. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  137. heatmaps = encoded['heatmaps']
  138. keypoint_indices = encoded['keypoint_indices']
  139. tags = self._get_tags(
  140. heatmaps, keypoint_indices, tag_per_keypoint=True, tag_dim=2)
  141. # to Tensor
  142. batch_heatmaps = torch.from_numpy(heatmaps[None])
  143. batch_tags = torch.from_numpy(tags[None])
  144. batch_keypoints, batch_keypoint_scores = codec.batch_decode(
  145. batch_heatmaps, batch_tags)
  146. self.assertIsInstance(batch_keypoints, list)
  147. self.assertIsInstance(batch_keypoint_scores, list)
  148. self.assertEqual(len(batch_keypoints), 1)
  149. self.assertEqual(len(batch_keypoint_scores), 1)
  150. keypoints, scores = self._sort_preds(batch_keypoints[0],
  151. batch_keypoint_scores[0],
  152. data['keypoints'])
  153. self.assertIsInstance(keypoints, np.ndarray)
  154. self.assertIsInstance(scores, np.ndarray)
  155. self.assertEqual(keypoints.shape, (2, 17, 2))
  156. self.assertEqual(scores.shape, (2, 17))
  157. self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
  158. # w/ UDP
  159. codec = AssociativeEmbedding(
  160. input_size=(256, 256),
  161. heatmap_size=(64, 64),
  162. use_udp=True,
  163. decode_keypoint_order=self.decode_keypoint_order)
  164. encoded = codec.encode(data['keypoints'], data['keypoints_visible'])
  165. heatmaps = encoded['heatmaps']
  166. keypoint_indices = encoded['keypoint_indices']
  167. tags = self._get_tags(
  168. heatmaps, keypoint_indices, tag_per_keypoint=True)
  169. # to Tensor
  170. batch_heatmaps = torch.from_numpy(heatmaps[None])
  171. batch_tags = torch.from_numpy(tags[None])
  172. batch_keypoints, batch_keypoint_scores = codec.batch_decode(
  173. batch_heatmaps, batch_tags)
  174. self.assertIsInstance(batch_keypoints, list)
  175. self.assertIsInstance(batch_keypoint_scores, list)
  176. self.assertEqual(len(batch_keypoints), 1)
  177. self.assertEqual(len(batch_keypoint_scores), 1)
  178. keypoints, scores = self._sort_preds(batch_keypoints[0],
  179. batch_keypoint_scores[0],
  180. data['keypoints'])
  181. self.assertIsInstance(keypoints, np.ndarray)
  182. self.assertIsInstance(scores, np.ndarray)
  183. self.assertEqual(keypoints.shape, (2, 17, 2))
  184. self.assertEqual(scores.shape, (2, 17))
  185. self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))