# test_pose_visualizer.py (3.9 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. from unittest import TestCase
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from mmengine.structures import InstanceData, PixelData
  8. from mmpose.structures import PoseDataSample
  9. from mmpose.visualization import PoseLocalVisualizer
  10. class TestPoseLocalVisualizer(TestCase):
  11. def setUp(self):
  12. self.visualizer = PoseLocalVisualizer(show_keypoint_weight=True)
  13. def _get_dataset_meta(self):
  14. # None: kpt or link is hidden
  15. pose_kpt_color = [None] + [(127, 127, 127)] * 2 + ['red']
  16. pose_link_color = [(127, 127, 127)] * 2 + [None]
  17. skeleton_links = [[0, 1], [1, 2], [2, 3]]
  18. return {
  19. 'keypoint_colors': pose_kpt_color,
  20. 'skeleton_link_colors': pose_link_color,
  21. 'skeleton_links': skeleton_links
  22. }
  23. def test_set_dataset_meta(self):
  24. dataset_meta = self._get_dataset_meta()
  25. self.visualizer.set_dataset_meta(dataset_meta)
  26. self.assertEqual(len(self.visualizer.kpt_color), 4)
  27. self.assertEqual(self.visualizer.kpt_color[-1], 'red')
  28. self.assertListEqual(self.visualizer.skeleton[-1], [2, 3])
  29. self.visualizer.dataset_meta = None
  30. self.visualizer.set_dataset_meta(dataset_meta)
  31. self.assertIsNotNone(self.visualizer.dataset_meta)
  32. def test_add_datasample(self):
  33. h, w = 100, 100
  34. image = np.zeros((h, w, 3), dtype=np.uint8)
  35. out_file = 'out_file.jpg'
  36. dataset_meta = self._get_dataset_meta()
  37. self.visualizer.set_dataset_meta(dataset_meta)
  38. # setting keypoints
  39. gt_instances = InstanceData()
  40. gt_instances.keypoints = np.array([[[1, 1], [20, 20], [40, 40],
  41. [80, 80]]],
  42. dtype=np.float32)
  43. # setting bounding box
  44. gt_instances.bboxes = np.array([[20, 30, 50, 70]])
  45. # setting heatmap
  46. heatmap = torch.randn(10, 100, 100) * 0.05
  47. for i in range(10):
  48. heatmap[i][i * 10:(i + 1) * 10, i * 10:(i + 1) * 10] += 5
  49. gt_heatmap = PixelData()
  50. gt_heatmap.heatmaps = heatmap
  51. # test gt_sample
  52. pred_pose_data_sample = PoseDataSample()
  53. pred_pose_data_sample.gt_instances = gt_instances
  54. pred_pose_data_sample.gt_fields = gt_heatmap
  55. pred_instances = gt_instances.clone()
  56. pred_instances.scores = np.array([[0.9, 0.4, 1.7, -0.2]],
  57. dtype=np.float32)
  58. pred_pose_data_sample.pred_instances = pred_instances
  59. self.visualizer.add_datasample(
  60. 'image',
  61. image,
  62. data_sample=pred_pose_data_sample,
  63. draw_bbox=True,
  64. out_file=out_file)
  65. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  66. self.visualizer.show_keypoint_weight = False
  67. self.visualizer.add_datasample(
  68. 'image',
  69. image,
  70. data_sample=pred_pose_data_sample,
  71. draw_pred=False,
  72. draw_heatmap=True,
  73. out_file=out_file)
  74. self._assert_image_and_shape(out_file, ((h * 2), w, 3))
  75. self.visualizer.add_datasample(
  76. 'image',
  77. image,
  78. data_sample=pred_pose_data_sample,
  79. draw_heatmap=True,
  80. out_file=out_file)
  81. self._assert_image_and_shape(out_file, ((h * 2), (w * 2), 3))
  82. def test_simcc_visualization(self):
  83. img = np.zeros((512, 512, 3), dtype=np.uint8)
  84. heatmap = torch.randn([17, 512, 512])
  85. pixelData = PixelData()
  86. pixelData.heatmaps = heatmap
  87. self.visualizer._draw_instance_xy_heatmap(pixelData, img, 10)
  88. def _assert_image_and_shape(self, out_file, out_shape):
  89. self.assertTrue(os.path.exists(out_file))
  90. drawn_img = cv2.imread(out_file)
  91. self.assertTupleEqual(drawn_img.shape, out_shape)
  92. os.remove(out_file)