test_pose_data_sample.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.structures import InstanceData, PixelData
  6. from mmpose.structures import MultilevelPixelData, PoseDataSample
  7. class TestPoseDataSample(TestCase):
  8. def get_pose_data_sample(self, multilevel: bool = False):
  9. # meta
  10. pose_meta = dict(
  11. img_shape=(600, 900), # [h, w, c]
  12. crop_size=(256, 192), # [h, w]
  13. heatmap_size=(64, 48), # [h, w]
  14. )
  15. # gt_instances
  16. gt_instances = InstanceData()
  17. gt_instances.bboxes = torch.rand(1, 4)
  18. gt_instances.keypoints = torch.rand(1, 17, 2)
  19. gt_instances.keypoints_visible = torch.rand(1, 17)
  20. # pred_instances
  21. pred_instances = InstanceData()
  22. pred_instances.keypoints = torch.rand(1, 17, 2)
  23. pred_instances.keypoint_scores = torch.rand(1, 17)
  24. # gt_fields
  25. if multilevel:
  26. # generate multilevel gt_fields
  27. metainfo = dict(num_keypoints=17)
  28. sizes = [(64, 48), (32, 24), (16, 12)]
  29. heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
  30. masks = [torch.rand(1, h, w) for h, w in sizes]
  31. gt_fields = MultilevelPixelData(
  32. metainfo=metainfo, heatmaps=heatmaps, masks=masks)
  33. else:
  34. gt_fields = PixelData()
  35. gt_fields.heatmaps = torch.rand(17, 64, 48)
  36. # pred_fields
  37. pred_fields = PixelData()
  38. pred_fields.heatmaps = torch.rand(17, 64, 48)
  39. data_sample = PoseDataSample(
  40. gt_instances=gt_instances,
  41. pred_instances=pred_instances,
  42. gt_fields=gt_fields,
  43. pred_fields=pred_fields,
  44. metainfo=pose_meta)
  45. return data_sample
  46. @staticmethod
  47. def _equal(x, y):
  48. if type(x) != type(y):
  49. return False
  50. if isinstance(x, torch.Tensor):
  51. return torch.allclose(x, y)
  52. elif isinstance(x, np.ndarray):
  53. return np.allclose(x, y)
  54. else:
  55. return x == y
  56. def test_init(self):
  57. data_sample = self.get_pose_data_sample()
  58. self.assertIn('img_shape', data_sample)
  59. self.assertTrue(len(data_sample.gt_instances) == 1)
  60. def test_setter(self):
  61. data_sample = self.get_pose_data_sample()
  62. # test gt_instances
  63. data_sample.gt_instances = InstanceData()
  64. # test gt_fields
  65. data_sample.gt_fields = PixelData()
  66. # test multilevel gt_fields
  67. data_sample = self.get_pose_data_sample(multilevel=True)
  68. data_sample.gt_fields = MultilevelPixelData()
  69. # test pred_instances as pytorch tensor
  70. pred_instances_data = dict(
  71. keypoints=torch.rand(1, 17, 2), scores=torch.rand(1, 17, 1))
  72. data_sample.pred_instances = InstanceData(**pred_instances_data)
  73. self.assertTrue(
  74. self._equal(data_sample.pred_instances.keypoints,
  75. pred_instances_data['keypoints']))
  76. self.assertTrue(
  77. self._equal(data_sample.pred_instances.scores,
  78. pred_instances_data['scores']))
  79. # test pred_fields as numpy array
  80. pred_fields_data = dict(heatmaps=np.random.rand(17, 64, 48))
  81. data_sample.pred_fields = PixelData(**pred_fields_data)
  82. self.assertTrue(
  83. self._equal(data_sample.pred_fields.heatmaps,
  84. pred_fields_data['heatmaps']))
  85. # test to_tensor
  86. data_sample = data_sample.to_tensor()
  87. self.assertTrue(
  88. self._equal(data_sample.pred_fields.heatmaps,
  89. torch.from_numpy(pred_fields_data['heatmaps'])))
  90. def test_deleter(self):
  91. data_sample = self.get_pose_data_sample()
  92. for key in [
  93. 'gt_instances',
  94. 'pred_instances',
  95. 'gt_fields',
  96. 'pred_fields',
  97. ]:
  98. self.assertIn(key, data_sample)
  99. exec(f'del data_sample.{key}')
  100. self.assertNotIn(key, data_sample)