# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer
class TestPoseLocalVisualizer(TestCase):
    """Unit tests for ``PoseLocalVisualizer``."""

    def setUp(self):
        """Build a visualizer with ``show_keypoint_weight=True`` per test."""
        self.visualizer = PoseLocalVisualizer(show_keypoint_weight=True)

    def _get_dataset_meta(self):
        """Return a small dataset meta dict with 4 keypoints and 3 links.

        Returns:
            dict: ``keypoint_colors``, ``skeleton_link_colors`` and
            ``skeleton_links`` entries used to configure the visualizer.
        """
        # None: kpt or link is hidden
        pose_kpt_color = [None] + [(127, 127, 127)] * 2 + ['red']
        pose_link_color = [(127, 127, 127)] * 2 + [None]
        skeleton_links = [[0, 1], [1, 2], [2, 3]]
        return {
            'keypoint_colors': pose_kpt_color,
            'skeleton_link_colors': pose_link_color,
            'skeleton_links': skeleton_links
        }

    def test_set_dataset_meta(self):
        """Check that ``set_dataset_meta`` populates the visualizer state."""
        dataset_meta = self._get_dataset_meta()
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertEqual(len(self.visualizer.kpt_color), 4)
        self.assertEqual(self.visualizer.kpt_color[-1], 'red')
        self.assertListEqual(self.visualizer.skeleton[-1], [2, 3])

        # Setting the meta again after clearing it must restore it.
        self.visualizer.dataset_meta = None
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertIsNotNone(self.visualizer.dataset_meta)

    def test_add_datasample(self):
        """Check the shape of images saved by ``add_datasample``."""
        h, w = 100, 100
        image = np.zeros((h, w, 3), dtype=np.uint8)

        # Write into a private temporary directory instead of the current
        # working directory; it is removed even if the test fails midway.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        out_file = os.path.join(tmp_dir.name, 'out_file.jpg')

        dataset_meta = self._get_dataset_meta()
        self.visualizer.set_dataset_meta(dataset_meta)

        # setting keypoints
        gt_instances = InstanceData()
        gt_instances.keypoints = np.array(
            [[[1, 1], [20, 20], [40, 40], [80, 80]]], dtype=np.float32)

        # setting bounding box
        gt_instances.bboxes = np.array([[20, 30, 50, 70]])

        # setting heatmap: low-amplitude noise plus one strong square blob
        # per channel, placed along the diagonal
        heatmap = torch.randn(10, 100, 100) * 0.05
        for i in range(10):
            heatmap[i][i * 10:(i + 1) * 10, i * 10:(i + 1) * 10] += 5
        gt_heatmap = PixelData()
        gt_heatmap.heatmaps = heatmap

        # data sample carrying both ground-truth and predicted instances
        pred_pose_data_sample = PoseDataSample()
        pred_pose_data_sample.gt_instances = gt_instances
        pred_pose_data_sample.gt_fields = gt_heatmap
        pred_instances = gt_instances.clone()
        # scores deliberately include values outside [0, 1]
        pred_instances.scores = np.array([[0.9, 0.4, 1.7, -0.2]],
                                         dtype=np.float32)
        pred_pose_data_sample.pred_instances = pred_instances

        # default drawing with bboxes: output expected to be twice as wide
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_bbox=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        # heatmap without predictions: output expected to be twice as tall
        self.visualizer.show_keypoint_weight = False
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_pred=False,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, ((h * 2), w, 3))

        # heatmap with predictions: output expected to double in both dims
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, ((h * 2), (w * 2), 3))

    def test_simcc_visualization(self):
        """Smoke test: ``_draw_instance_xy_heatmap`` runs without raising."""
        img = np.zeros((512, 512, 3), dtype=np.uint8)
        heatmap = torch.randn([17, 512, 512])
        pixel_data = PixelData()
        pixel_data.heatmaps = heatmap
        self.visualizer._draw_instance_xy_heatmap(pixel_data, img, 10)

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert that ``out_file`` is a readable image of ``out_shape``.

        The file is deleted afterwards, whether or not the assertions pass,
        so that a later existence check proves a fresh write happened.

        Args:
            out_file (str): Path of the image expected on disk.
            out_shape (tuple): Expected ``(height, width, channels)``.
        """
        try:
            self.assertTrue(os.path.exists(out_file))
            drawn_img = cv2.imread(out_file)
            # cv2.imread signals an unreadable file by returning None
            # instead of raising, so fail with a clear message here.
            self.assertIsNotNone(drawn_img)
            self.assertTupleEqual(drawn_img.shape, out_shape)
        finally:
            # Guarded so a missing file does not mask the assertion error.
            if os.path.exists(out_file):
                os.remove(out_file)
|