12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- import shutil
- import time
- from unittest import TestCase
- from unittest.mock import MagicMock
- import numpy as np
- from mmengine.structures import InstanceData
- from mmpose.engine.hooks import PoseVisualizationHook
- from mmpose.structures import PoseDataSample
- from mmpose.visualization import PoseLocalVisualizer
def _rand_poses(num_boxes, h, w):
    """Generate random keypoint coordinates that lie inside an image.

    Args:
        num_boxes (int): Number of pose instances to generate.
        h (int | float): Image height; scales the y coordinates.
        w (int | float): Image width; scales the x coordinates.

    Returns:
        np.ndarray: Array of shape ``(num_boxes, 5, 2)`` holding ``(x, y)``
        for 5 keypoints per instance, with ``0 <= x <= w`` and
        ``0 <= y <= h``.
    """
    center = np.random.rand(num_boxes, 2)
    offset = np.random.rand(num_boxes, 5, 2) / 2.0

    # Clip the normalized position (center + offset), not the offset alone:
    # the offset is already within [0, 0.5), while the sum can reach 1.5 and
    # would otherwise place keypoints outside the image.
    pose = (center[:, None, :] + offset).clip(0, 1)
    pose[:, :, 0] *= w
    pose[:, :, 1] *= h

    return pose
class TestVisualizationHook(TestCase):
    """Exercise ``PoseVisualizationHook`` on validation and test iterations."""

    def setUp(self) -> None:
        """Build the visualizer instance plus a fake data batch and outputs.

        ``self.data_batch`` holds two references to one data sample whose
        ``img_path`` points at a COCO test image; ``self.outputs`` holds two
        references to a clone of that sample carrying random predictions.
        """
        PoseLocalVisualizer.get_instance('test_visualization_hook')

        data_sample = PoseDataSample()
        data_sample.set_metainfo({
            'img_path':
            osp.join(
                osp.dirname(__file__), '../../data/coco/000000000785.jpg')
        })
        self.data_batch = {'data_samples': [data_sample] * 2}

        pred_instances = InstanceData()
        pred_instances.keypoints = _rand_poses(5, 10, 12)
        pred_instances.score = np.random.rand(5, 5)
        pred_det_data_sample = data_sample.clone()
        pred_det_data_sample.pred_instances = pred_instances
        self.outputs = [pred_det_data_sample] * 2

    def test_after_val_iter(self):
        """Run ``after_val_iter`` once with an enabled hook and a mock runner."""
        runner = MagicMock()
        runner.iter = 1
        runner.val_evaluator.dataset_meta = dict()
        hook = PoseVisualizationHook(interval=1, enable=True)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        """Check the test-sample counter and the ``out_dir`` output handling."""
        runner = MagicMock()
        runner.iter = 1
        hook = PoseVisualizationHook(enable=True)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        # Two outputs were passed in, so the hook's counter should read 2.
        self.assertEqual(hook._test_index, 2)

        # With ``out_dir`` set, the output directory must appear only when
        # the hook is enabled.
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        out_dir = timestamp + '1'
        runner.work_dir = timestamp
        runner.timestamp = '1'
        # Register the removal up front so the directory is deleted even if
        # one of the assertions below fails.
        self.addCleanup(shutil.rmtree, f'{timestamp}', ignore_errors=True)

        hook = PoseVisualizationHook(enable=False, out_dir=out_dir)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertFalse(osp.exists(f'{timestamp}/1/{out_dir}'))

        hook = PoseVisualizationHook(enable=True, out_dir=out_dir)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}'))
|