test_converting.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. from mmpose.datasets.transforms import KeypointConverter
  4. from mmpose.testing import get_coco_sample
  5. class TestKeypointConverter(TestCase):
  6. def setUp(self):
  7. # prepare dummy bottom-up data sample with COCO metainfo
  8. self.data_info = get_coco_sample(
  9. img_shape=(240, 320), num_instances=4, with_bbox_cs=True)
  10. def test_transform(self):
  11. # 1-to-1 mapping
  12. mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
  13. transform = KeypointConverter(num_keypoints=5, mapping=mapping)
  14. results = transform(self.data_info.copy())
  15. # check shape
  16. self.assertEqual(results['keypoints'].shape[0],
  17. self.data_info['keypoints'].shape[0])
  18. self.assertEqual(results['keypoints'].shape[1], 5)
  19. self.assertEqual(results['keypoints'].shape[2], 2)
  20. self.assertEqual(results['keypoints_visible'].shape[0],
  21. self.data_info['keypoints_visible'].shape[0])
  22. self.assertEqual(results['keypoints_visible'].shape[1], 5)
  23. # check value
  24. for source_index, target_index in mapping:
  25. self.assertTrue((results['keypoints'][:, target_index] ==
  26. self.data_info['keypoints'][:,
  27. source_index]).all())
  28. self.assertTrue(
  29. (results['keypoints_visible'][:, target_index] ==
  30. self.data_info['keypoints_visible'][:, source_index]).all())
  31. # 2-to-1 mapping
  32. mapping = [((3, 5), 0), (6, 1), (16, 2), (5, 3)]
  33. transform = KeypointConverter(num_keypoints=5, mapping=mapping)
  34. results = transform(self.data_info.copy())
  35. # check shape
  36. self.assertEqual(results['keypoints'].shape[0],
  37. self.data_info['keypoints'].shape[0])
  38. self.assertEqual(results['keypoints'].shape[1], 5)
  39. self.assertEqual(results['keypoints'].shape[2], 2)
  40. self.assertEqual(results['keypoints_visible'].shape[0],
  41. self.data_info['keypoints_visible'].shape[0])
  42. self.assertEqual(results['keypoints_visible'].shape[1], 5)
  43. # check value
  44. for source_index, target_index in mapping:
  45. if isinstance(source_index, tuple):
  46. source_index, source_index2 = source_index
  47. self.assertTrue(
  48. (results['keypoints'][:, target_index] == 0.5 *
  49. (self.data_info['keypoints'][:, source_index] +
  50. self.data_info['keypoints'][:, source_index2])).all())
  51. self.assertTrue(
  52. (results['keypoints_visible'][:, target_index] ==
  53. self.data_info['keypoints_visible'][:, source_index] *
  54. self.data_info['keypoints_visible'][:,
  55. source_index2]).all())
  56. else:
  57. self.assertTrue(
  58. (results['keypoints'][:, target_index] ==
  59. self.data_info['keypoints'][:, source_index]).all())
  60. self.assertTrue(
  61. (results['keypoints_visible'][:, target_index] ==
  62. self.data_info['keypoints_visible'][:,
  63. source_index]).all())