# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional, Sequence

import mmcv
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from mmengine.visualization import Visualizer

from mmdet.registry import HOOKS
from mmdet.structures import DetDataSample


@HOOKS.register_module()
class DetVisualizationHook(Hook):
    """Detection Visualization Hook. Used to visualize prediction results
    during validation and testing.

    In the testing phase:

    1. If ``show`` is True, only the prediction results are visualized
       without storing data, so ``vis_backends`` needs to be excluded.
    2. If ``test_out_dir`` is specified, the prediction results are saved
       to ``test_out_dir``. To avoid ``vis_backends`` also storing data,
       ``vis_backends`` needs to be excluded.
    3. ``vis_backends`` takes effect only if the user specifies neither
       ``show`` nor ``test_out_dir``. You can set ``vis_backends`` to
       WandbVisBackend or TensorboardVisBackend to store the prediction
       results in Wandb or Tensorboard.

    Args:
        draw (bool): Whether to draw prediction results. If False, nothing
            is drawn. Defaults to False.
        interval (int): The interval of visualization, in validation
            iterations. Defaults to 50.
        score_thr (float): The score threshold for visualizing the bboxes
            and masks. Defaults to 0.3.
        show (bool): Whether to display the drawn image. Defaults to False.
        wait_time (float): The display time of each drawn image in seconds.
            Defaults to 0.
        test_out_dir (str, optional): Directory where painted images will be
            saved during the testing process. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding file backend. Defaults to None.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        if self.show:
            # No need to keep any vis backends when showing interactively.
            self._visualizer._vis_backends = {}
            warnings.warn('`show` is True, so only the prediction results '
                          'are visualized without storing data; '
                          'vis_backends will be excluded.')

        self.wait_time = wait_time
        self.backend_args = backend_args
        self.draw = draw
        self.test_out_dir = test_out_dir
        self._test_index = 0

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[DetDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx

        # Visualize only the first data sample of the batch.
        img_path = outputs[0].img_path
        img_bytes = get(img_path, backend_args=self.backend_args)
        img = mmcv.imfrombytes(img_bytes, channel_order='rgb')

        if total_curr_iter % self.interval == 0:
            self._visualizer.add_datasample(
                osp.basename(img_path) if self.show else 'val_img',
                img,
                data_sample=outputs[0],
                show=self.show,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                step=total_curr_iter)
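
    # Note (editorial sketch): ``backend_args`` is forwarded unchanged to
    # ``mmengine.fileio.get`` when reading images back from ``img_path``, so
    # images can also be fetched from remote storage. A hypothetical example
    # for a Petrel/Ceph bucket, assuming that backend is installed:
    #
    #   backend_args = dict(backend='petrel')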

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[DetDataSample]) -> None:
        """Run after every testing iteration.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        if self.test_out_dir is not None:
            self.test_out_dir = osp.join(runner.work_dir, runner.timestamp,
                                         self.test_out_dir)
            mkdir_or_exist(self.test_out_dir)

        for data_sample in outputs:
            self._test_index += 1

            img_path = data_sample.img_path
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, channel_order='rgb')

            out_file = None
            if self.test_out_dir is not None:
                out_file = osp.basename(img_path)
                out_file = osp.join(self.test_out_dir, out_file)

            self._visualizer.add_datasample(
                osp.basename(img_path) if self.show else 'test_img',
                img,
                data_sample=data_sample,
                show=self.show,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                out_file=out_file,
                step=self._test_index)
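

# Test-time usage sketch (editorial addition; the config keys follow the usual
# MMDetection conventions and the directory name 'vis' is just an example):
# save painted predictions under ``<work_dir>/<timestamp>/vis``. Per the class
# docstring, leave ``vis_backends`` out of the visualizer config so the
# results are not stored twice.
#
#   default_hooks = dict(
#       visualization=dict(
#           type='DetVisualizationHook',
#           draw=True,
#           test_out_dir='vis',
#           score_thr=0.3))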
|