# Copyright (c) OpenMMLab. All rights reserved. import os from unittest import TestCase import cv2 import numpy as np import torch from mmengine.structures import InstanceData, PixelData from mmpose.structures import PoseDataSample from mmpose.visualization import PoseLocalVisualizer class TestPoseLocalVisualizer(TestCase): def setUp(self): self.visualizer = PoseLocalVisualizer(show_keypoint_weight=True) def _get_dataset_meta(self): # None: kpt or link is hidden pose_kpt_color = [None] + [(127, 127, 127)] * 2 + ['red'] pose_link_color = [(127, 127, 127)] * 2 + [None] skeleton_links = [[0, 1], [1, 2], [2, 3]] return { 'keypoint_colors': pose_kpt_color, 'skeleton_link_colors': pose_link_color, 'skeleton_links': skeleton_links } def test_set_dataset_meta(self): dataset_meta = self._get_dataset_meta() self.visualizer.set_dataset_meta(dataset_meta) self.assertEqual(len(self.visualizer.kpt_color), 4) self.assertEqual(self.visualizer.kpt_color[-1], 'red') self.assertListEqual(self.visualizer.skeleton[-1], [2, 3]) self.visualizer.dataset_meta = None self.visualizer.set_dataset_meta(dataset_meta) self.assertIsNotNone(self.visualizer.dataset_meta) def test_add_datasample(self): h, w = 100, 100 image = np.zeros((h, w, 3), dtype=np.uint8) out_file = 'out_file.jpg' dataset_meta = self._get_dataset_meta() self.visualizer.set_dataset_meta(dataset_meta) # setting keypoints gt_instances = InstanceData() gt_instances.keypoints = np.array([[[1, 1], [20, 20], [40, 40], [80, 80]]], dtype=np.float32) # setting bounding box gt_instances.bboxes = np.array([[20, 30, 50, 70]]) # setting heatmap heatmap = torch.randn(10, 100, 100) * 0.05 for i in range(10): heatmap[i][i * 10:(i + 1) * 10, i * 10:(i + 1) * 10] += 5 gt_heatmap = PixelData() gt_heatmap.heatmaps = heatmap # test gt_sample pred_pose_data_sample = PoseDataSample() pred_pose_data_sample.gt_instances = gt_instances pred_pose_data_sample.gt_fields = gt_heatmap pred_instances = gt_instances.clone() pred_instances.scores = np.array([[0.9, 0.4, 1.7, -0.2]], dtype=np.float32) pred_pose_data_sample.pred_instances = pred_instances self.visualizer.add_datasample( 'image', image, data_sample=pred_pose_data_sample, draw_bbox=True, out_file=out_file) self._assert_image_and_shape(out_file, (h, w * 2, 3)) self.visualizer.show_keypoint_weight = False self.visualizer.add_datasample( 'image', image, data_sample=pred_pose_data_sample, draw_pred=False, draw_heatmap=True, out_file=out_file) self._assert_image_and_shape(out_file, ((h * 2), w, 3)) self.visualizer.add_datasample( 'image', image, data_sample=pred_pose_data_sample, draw_heatmap=True, out_file=out_file) self._assert_image_and_shape(out_file, ((h * 2), (w * 2), 3)) def test_simcc_visualization(self): img = np.zeros((512, 512, 3), dtype=np.uint8) heatmap = torch.randn([17, 512, 512]) pixelData = PixelData() pixelData.heatmaps = heatmap self.visualizer._draw_instance_xy_heatmap(pixelData, img, 10) def _assert_image_and_shape(self, out_file, out_shape): self.assertTrue(os.path.exists(out_file)) drawn_img = cv2.imread(out_file) self.assertTupleEqual(drawn_img.shape, out_shape) os.remove(out_file)