test_visualization_hook.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import tempfile
import time
from unittest import TestCase
from unittest.mock import MagicMock

import numpy as np
from mmengine.structures import InstanceData

from mmpose.engine.hooks import PoseVisualizationHook
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer
  12. def _rand_poses(num_boxes, h, w):
  13. center = np.random.rand(num_boxes, 2)
  14. offset = np.random.rand(num_boxes, 5, 2) / 2.0
  15. pose = center[:, None, :] + offset.clip(0, 1)
  16. pose[:, :, 0] *= w
  17. pose[:, :, 1] *= h
  18. return pose
  19. class TestVisualizationHook(TestCase):
  20. def setUp(self) -> None:
  21. PoseLocalVisualizer.get_instance('test_visualization_hook')
  22. data_sample = PoseDataSample()
  23. data_sample.set_metainfo({
  24. 'img_path':
  25. osp.join(
  26. osp.dirname(__file__), '../../data/coco/000000000785.jpg')
  27. })
  28. self.data_batch = {'data_samples': [data_sample] * 2}
  29. pred_instances = InstanceData()
  30. pred_instances.keypoints = _rand_poses(5, 10, 12)
  31. pred_instances.score = np.random.rand(5, 5)
  32. pred_det_data_sample = data_sample.clone()
  33. pred_det_data_sample.pred_instances = pred_instances
  34. self.outputs = [pred_det_data_sample] * 2
  35. def test_after_val_iter(self):
  36. runner = MagicMock()
  37. runner.iter = 1
  38. runner.val_evaluator.dataset_meta = dict()
  39. hook = PoseVisualizationHook(interval=1, enable=True)
  40. hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
  41. def test_after_test_iter(self):
  42. runner = MagicMock()
  43. runner.iter = 1
  44. hook = PoseVisualizationHook(enable=True)
  45. hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
  46. self.assertEqual(hook._test_index, 2)
  47. # test
  48. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  49. out_dir = timestamp + '1'
  50. runner.work_dir = timestamp
  51. runner.timestamp = '1'
  52. hook = PoseVisualizationHook(enable=False, out_dir=out_dir)
  53. hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
  54. self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}'))
  55. hook = PoseVisualizationHook(enable=True, out_dir=out_dir)
  56. hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
  57. self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}'))
  58. shutil.rmtree(f'{timestamp}')