test_decoupled_heatmap.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import DecoupledHeatmap
  5. from mmpose.registry import KEYPOINT_CODECS
  6. from mmpose.testing import get_coco_sample
  class TestDecoupledHeatmap(TestCase):
      """Tests for the ``DecoupledHeatmap`` keypoint codec.

      Covers building from the registry, encoding (with and without a bbox,
      and with different ``root_type`` values), decoding, and an
      encode -> decode round trip.
      """

      def setUp(self) -> None:
          pass

      def _get_data(self, num_instances):
          """Return a 512x512 COCO sample whose ``bbox`` is given as corners.

          NOTE(review): assumes ``get_coco_sample`` returns ``bbox`` as
          (N, 4) ``[x1, y1, x2, y2]``. Tiling and reshaping gives (N, 4, 2)
          points ``(x1, y1), (x2, y2), (x1, y1), (x2, y2)``; copying the x of
          points 0..1 into points 1..2 turns them into the four corners
          ``(x1, y1), (x1, y2), (x2, y1), (x2, y2)``.
          """
          data = get_coco_sample(
              img_shape=(512, 512), num_instances=num_instances)
          data['bbox'] = np.tile(data['bbox'], 2).reshape(-1, 4, 2)
          data['bbox'][:, 1:3, 0] = data['bbox'][:, 0:2, 0]
          return data

      def _build_codec(self, **kwargs):
          """Build a ``DecoupledHeatmap`` mapping 512x512 inputs to 128x128
          heatmaps; extra keyword arguments (e.g. ``root_type``) are passed
          straight to the constructor."""
          return DecoupledHeatmap(
              input_size=(512, 512),
              heatmap_size=(128, 128),
              **kwargs,
          )

      def _assert_encoded_shapes(self, encoded):
          """Assert the shapes ``test_encode`` expects from ``codec.encode``.

          NOTE(review): the expected values (2 kept instances, 17 keypoints,
          18 heatmap channels, 34 = 2 * 17 instance heatmaps) are taken
          as-is from the original assertions; presumably the fully
          invisible "outside" instance is dropped by the codec -- confirm.
          """
          self.assertEqual(encoded['heatmaps'].shape, (18, 128, 128))
          self.assertEqual(encoded['keypoint_weights'].shape, (2, 17))
          self.assertEqual(encoded['instance_heatmaps'].shape,
                           (34, 128, 128))
          self.assertEqual(encoded['instance_coords'].shape, (2, 2))

      def _decode_encoded(self, codec, encoded):
          """Decode the output of ``codec.encode`` with unit instance scores.

          The flat ``instance_heatmaps`` are regrouped to
          ``(num_instances, -1, H, W)``, where ``num_instances`` is the
          number of rows of ``instance_coords``. Scores are passed as a
          ``(num_instances, 1)`` array of ones. Returns whatever
          ``codec.decode`` returns; the tests here unpack it as
          ``(keypoints, keypoint_scores)``.
          """
          num_instances = encoded['instance_coords'].shape[0]
          flat_heatmaps = encoded['instance_heatmaps']
          instance_heatmaps = flat_heatmaps.reshape(
              num_instances, -1, *flat_heatmaps.shape[-2:])
          instance_scores = np.ones(num_instances)[:, None]
          return codec.decode(instance_heatmaps, instance_scores)

      def _make_multi_instance_data(self, data):
          """Append two synthetic copies of every instance to ``data``.

          The appended copies are:
            * an "outside" copy, shifted by ``-keypoints.max(axis=1)``, whose
              visibility is all zeros;
            * an "overlap" copy, scaled by 0.8 towards its mean along
              axis 1, sharing the original visibility.

          ``data`` is modified in place and returned. Before copying,
          ``keypoints_visible[..., 0]`` is zeroed (presumably marking
          keypoint 0 of each instance invisible, if visibility is (N, K)).

          NOTE(review): the callers in this class pass ``bbox`` as (N, 4, 2)
          corner arrays, so ``reshape(-1, 2, 2)`` yields (2N, 2, 2) rows and
          ``bbox.mean(axis=1)`` averages pairs of corners rather than all
          four; the broadcast against ``offset`` (N, 1, 2) only works for
          N == 1. Confirm this is intended.
          """
          bbox = data['bbox'].reshape(-1, 2, 2)
          keypoints = data['keypoints']
          keypoints_visible = data['keypoints_visible']

          keypoints_visible[..., 0] = 0

          # "Outside" instance: shift everything by the max keypoint coord.
          offset = keypoints.max(axis=1, keepdims=True)
          bbox_outside = bbox - offset
          keypoints_outside = keypoints - offset
          keypoints_outside_visible = np.zeros(keypoints_visible.shape)

          # "Overlap" instance: shrink each instance by 0.8 about its mean.
          bbox_center = bbox.mean(axis=1, keepdims=True)
          bbox_overlap = bbox_center + 0.8 * (bbox - bbox_center)
          keypoints_center = keypoints.mean(axis=1, keepdims=True)
          keypoint_overlap = keypoints_center + 0.8 * (
              keypoints - keypoints_center)
          keypoint_overlap_visible = keypoints_visible

          data['bbox'] = np.concatenate((bbox, bbox_outside, bbox_overlap),
                                        axis=0)
          data['keypoints'] = np.concatenate(
              (keypoints, keypoints_outside, keypoint_overlap), axis=0)
          data['keypoints_visible'] = np.concatenate(
              (keypoints_visible, keypoints_outside_visible,
               keypoint_overlap_visible),
              axis=0)

          return data

      def test_build(self):
          """The codec can be built from a config via ``KEYPOINT_CODECS``."""
          cfg = dict(
              type='DecoupledHeatmap',
              input_size=(512, 512),
              heatmap_size=(128, 128),
          )
          codec = KEYPOINT_CODECS.build(cfg)
          self.assertIsInstance(codec, DecoupledHeatmap)

      def test_encode(self):
          """Encoding yields the expected shapes for several configurations."""
          data = self._make_multi_instance_data(
              self._get_data(num_instances=1))
          codec = self._build_codec()

          # with bbox
          encoded = codec.encode(
              data['keypoints'], data['keypoints_visible'], bbox=data['bbox'])
          self._assert_encoded_shapes(encoded)

          # without bbox
          encoded = codec.encode(
              data['keypoints'], data['keypoints_visible'], bbox=None)
          self._assert_encoded_shapes(encoded)

          # root_type: an unsupported value ('box_center') must raise
          # ValueError, either at construction or at encode time.
          with self.assertRaises(ValueError):
              codec = self._build_codec(root_type='box_center')
              codec.encode(
                  data['keypoints'],
                  data['keypoints_visible'],
                  bbox=data['bbox'])

          codec = self._build_codec(root_type='bbox_center')
          encoded = codec.encode(
              data['keypoints'], data['keypoints_visible'], bbox=data['bbox'])
          self._assert_encoded_shapes(encoded)

      def test_decode(self):
          """Decoding returns one keypoint set and score row per instance."""
          data = self._get_data(num_instances=2)
          codec = self._build_codec()
          encoded = codec.encode(
              data['keypoints'], data['keypoints_visible'], bbox=data['bbox'])

          keypoints, keypoint_scores = self._decode_encoded(codec, encoded)

          self.assertEqual(keypoints.shape, (2, 17, 2))
          self.assertEqual(keypoint_scores.shape, (2, 17))

      def test_cicular_verification(self):
          """An encode -> decode round trip recovers the input keypoints.

          (The misspelled method name is kept so existing test ids and
          selectors keep working.)
          """
          data = self._get_data(num_instances=1)
          codec = self._build_codec()
          encoded = codec.encode(
              data['keypoints'], data['keypoints_visible'], bbox=data['bbox'])

          keypoints, _ = self._decode_encoded(codec, encoded)
          # NOTE(review): a fixed +1.5 px shift is applied before comparing;
          # presumably it compensates for a systematic offset in decoding --
          # confirm against the codec implementation.
          keypoints += 1.5
          self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=5.))