test_formatting.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from unittest import TestCase
  4. import numpy as np
  5. import torch
  6. from mmengine.structures import InstanceData, PixelData
  7. from mmpose.datasets.transforms import PackPoseInputs
  8. from mmpose.structures import PoseDataSample
  class TestPackPoseInputs(TestCase):
      """Tests for packing pose pipeline results with ``PackPoseInputs``."""

      def setUp(self):
          """Create the fixtures shared by all test methods.

          ``unittest`` runs this before each test method, so every test
          starts from a fresh ``self.results_topdown`` (a dummy top-down
          sample) and ``self.meta_keys`` (the meta fields handed to the
          transform).
          """
          height, width = 425, 640
          num_instances, num_keypoints = 1, 17

          def random_coords():
              # float32 array of shape (num_instances, num_keypoints, 2)
              # holding integer values drawn from [0, 100)
              coords = np.random.randint(
                  0, 100, (num_instances, num_keypoints, 2))
              return coords.astype(np.float32)

          def all_ones():
              # float32 array of ones, shape (num_instances, num_keypoints);
              # a new array per call so no two fields share storage
              return np.full(
                  (num_instances, num_keypoints), 1).astype(np.float32)

          # dummy top-down data sample with COCO metainfo
          self.results_topdown = {
              'img_id': 1,
              'img_path': 'tests/data/coco/000000000785.jpg',
              'id': 1,
              'ori_shape': (height, width),
              'img_shape': (height, width, 3),
              'scale_factor': 2.0,
              'flip': False,
              'flip_direction': None,
              'img': np.zeros((height, width, 3), dtype=np.uint8),
              'bbox': np.array([[0, 0, 100, 100]], dtype=np.float32),
              'bbox_center': np.array([[50, 50]], dtype=np.float32),
              'bbox_scale': np.array([[125, 125]], dtype=np.float32),
              'bbox_rotation': np.array([45], dtype=np.float32),
              'bbox_score': np.ones(1, dtype=np.float32),
              'keypoints': random_coords(),
              'keypoints_visible': all_ones(),
              'keypoint_weights': all_ones(),
              'heatmaps': np.random.random((17, 64, 48)).astype(np.float32),
              'keypoint_labels': random_coords(),
              'keypoint_x_labels': random_coords(),
              'keypoint_y_labels': random_coords(),
              'transformed_keypoints': random_coords(),
          }
          self.meta_keys = (
              'img_id',
              'img_path',
              'ori_shape',
              'img_shape',
              'scale_factor',
              'flip',
              'flip_direction',
          )

      def test_transform(self):
          """Check the packed output for a single image and a frame list."""
          # with pack_transformed=True the transformed keypoints are packed
          packer = PackPoseInputs(
              meta_keys=self.meta_keys, pack_transformed=True)
          packed = packer(copy.deepcopy(self.results_topdown))
          self.assertIn('transformed_keypoints',
                        packed['data_samples'].gt_instances)

          # default settings: single image input
          packer = PackPoseInputs(meta_keys=self.meta_keys)
          packed = packer(copy.deepcopy(self.results_topdown))

          self.assertIn('inputs', packed)
          self.assertIsInstance(packed['inputs'], torch.Tensor)
          self.assertEqual(packed['inputs'].shape, (3, 425, 640))

          self.assertIn('data_samples', packed)
          self.assertIsInstance(packed['data_samples'], PoseDataSample)
          self.assertIsInstance(packed['data_samples'].gt_instances,
                                InstanceData)
          self.assertIsInstance(packed['data_samples'].gt_fields, PixelData)
          self.assertEqual(len(packed['data_samples'].gt_instances), 1)
          self.assertIsInstance(packed['data_samples'].gt_fields.heatmaps,
                                torch.Tensor)
          # without pack_transformed the transformed keypoints are left out
          self.assertNotIn('transformed_keypoints',
                           packed['data_samples'].gt_instances)

          # test when results['img'] is sequence of frames
          len_seq = 5
          video_sample = copy.deepcopy(self.results_topdown)
          video_sample['img'] = [
              np.random.randint(0, 255, (425, 640, 3), dtype=np.uint8)
              for _ in range(len_seq)
          ]
          packed = packer(video_sample)

          self.assertIn('inputs', packed)
          self.assertIsInstance(packed['inputs'], torch.Tensor)
          # translate into 4-dim tensor: [len_seq, c, h, w]
          self.assertEqual(packed['inputs'].shape, (len_seq, 3, 425, 640))

      def test_repr(self):
          """Check that ``repr`` reports the configured ``meta_keys``."""
          packer = PackPoseInputs(meta_keys=self.meta_keys)
          expected = f'PackPoseInputs(meta_keys={self.meta_keys})'
          self.assertEqual(repr(packer), expected)