test_visualization_hook.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import shutil
  4. import time
  5. from unittest import TestCase
  6. from unittest.mock import Mock
  7. import torch
  8. from mmengine.structures import InstanceData
  9. from mmdet.engine.hooks import DetVisualizationHook
  10. from mmdet.structures import DetDataSample
  11. from mmdet.visualization import DetLocalVisualizer
  12. def _rand_bboxes(num_boxes, h, w):
  13. cx, cy, bw, bh = torch.rand(num_boxes, 4).T
  14. tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
  15. tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
  16. br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
  17. br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
  18. bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
  19. return bboxes
  20. class TestVisualizationHook(TestCase):
  21. def setUp(self) -> None:
  22. DetLocalVisualizer.get_instance('current_visualizer')
  23. pred_instances = InstanceData()
  24. pred_instances.bboxes = _rand_bboxes(5, 10, 12)
  25. pred_instances.labels = torch.randint(0, 2, (5, ))
  26. pred_instances.scores = torch.rand((5, ))
  27. pred_det_data_sample = DetDataSample()
  28. pred_det_data_sample.set_metainfo({
  29. 'img_path':
  30. osp.join(osp.dirname(__file__), '../../data/color.jpg')
  31. })
  32. pred_det_data_sample.pred_instances = pred_instances
  33. self.outputs = [pred_det_data_sample] * 2
  34. def test_after_val_iter(self):
  35. runner = Mock()
  36. runner.iter = 1
  37. hook = DetVisualizationHook()
  38. hook.after_val_iter(runner, 1, {}, self.outputs)
  39. def test_after_test_iter(self):
  40. runner = Mock()
  41. runner.iter = 1
  42. hook = DetVisualizationHook(draw=True)
  43. hook.after_test_iter(runner, 1, {}, self.outputs)
  44. self.assertEqual(hook._test_index, 2)
  45. # test
  46. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  47. test_out_dir = timestamp + '1'
  48. runner.work_dir = timestamp
  49. runner.timestamp = '1'
  50. hook = DetVisualizationHook(draw=False, test_out_dir=test_out_dir)
  51. hook.after_test_iter(runner, 1, {}, self.outputs)
  52. self.assertTrue(not osp.exists(f'{timestamp}/1/{test_out_dir}'))
  53. hook = DetVisualizationHook(draw=True, test_out_dir=test_out_dir)
  54. hook.after_test_iter(runner, 1, {}, self.outputs)
  55. self.assertTrue(osp.exists(f'{timestamp}/1/{test_out_dir}'))
  56. shutil.rmtree(f'{timestamp}')