123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Optional, Tuple, Union
- import cv2
- import mmcv
- import numpy as np
- import torch
- from mmengine.dist import master_only
- from mmengine.structures import InstanceData, PixelData
- from mmengine.visualization import Visualizer
- from ..evaluation import INSTANCE_OFFSET
- from ..registry import VISUALIZERS
- from ..structures import DetDataSample
- from ..structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon
- from .palette import _get_adaptive_scales, get_palette, jitter_color
@VISUALIZERS.register_module()
class DetLocalVisualizer(Visualizer):
    """MMDetection Local Visualizer.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        bbox_color (str, tuple(int), optional): Color of bbox lines.
            The tuple of color should be in BGR order. Defaults to None,
            in which case the dataset palette is used.
        text_color (str, tuple(int), optional): Color of texts.
            The tuple of color should be in BGR order.
            Defaults to (200, 200, 200).
        mask_color (str, tuple(int), optional): Color of masks.
            The tuple of color should be in BGR order.
            Defaults to None.
        line_width (int, float): The linewidth of lines.
            Defaults to 3.
        alpha (int, float): The transparency of bboxes or mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmdet.structures import DetDataSample
        >>> from mmdet.visualization import DetLocalVisualizer

        >>> det_local_visualizer = DetLocalVisualizer()
        >>> image = np.random.randint(0, 256,
        ...                     size=(10, 12, 3)).astype('uint8')
        >>> gt_instances = InstanceData()
        >>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]])
        >>> gt_instances.labels = torch.randint(0, 2, (1,))
        >>> gt_det_data_sample = DetDataSample()
        >>> gt_det_data_sample.gt_instances = gt_instances
        >>> det_local_visualizer.add_datasample('image', image,
        ...                         gt_det_data_sample)
        >>> det_local_visualizer.add_datasample(
        ...                       'image', image, gt_det_data_sample,
        ...                        out_file='out_file.jpg')
        >>> det_local_visualizer.add_datasample(
        ...                        'image', image, gt_det_data_sample,
        ...                         show=True)
        >>> pred_instances = InstanceData()
        >>> pred_instances.bboxes = torch.Tensor([[2, 4, 4, 8]])
        >>> pred_instances.labels = torch.randint(0, 2, (1,))
        >>> pred_det_data_sample = DetDataSample()
        >>> pred_det_data_sample.pred_instances = pred_instances
        >>> det_local_visualizer.add_datasample('image', image,
        ...                         gt_det_data_sample,
        ...                         pred_det_data_sample)
    """

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 bbox_color: Optional[Union[str, Tuple[int]]] = None,
                 text_color: Optional[Union[str,
                                            Tuple[int]]] = (200, 200, 200),
                 mask_color: Optional[Union[str, Tuple[int]]] = None,
                 line_width: Union[int, float] = 3,
                 alpha: float = 0.8) -> None:
        super().__init__(
            name=name,
            image=image,
            vis_backends=vis_backends,
            save_dir=save_dir)
        self.bbox_color = bbox_color
        self.text_color = text_color
        self.mask_color = mask_color
        self.line_width = line_width
        self.alpha = alpha
        # Set default value. When calling
        # `DetLocalVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        self.dataset_meta = {}

    @staticmethod
    def _get_label_text(label: int, classes: Optional[List[str]]) -> str:
        """Return the display name of ``label``.

        Falls back to ``'class {label}'`` when no class names are known.
        """
        return classes[label] if classes is not None else f'class {label}'

    @staticmethod
    def _text_box_props() -> List[dict]:
        """Return a fresh background-box style for label texts."""
        return [{
            'facecolor': 'black',
            'alpha': 0.8,
            'pad': 0.7,
            'edgecolor': 'none'
        }]

    @staticmethod
    def _largest_component(
            mask: np.ndarray) -> Optional[Tuple[np.ndarray, int]]:
        """Find the largest 8-connected foreground region of ``mask``.

        Args:
            mask (np.ndarray): A 2-D mask; non-zero pixels are foreground.

        Returns:
            tuple or None: ``(centroid, area)`` of the largest foreground
            component as reported by ``cv2.connectedComponentsWithStats``,
            or ``None`` if ``mask`` contains no foreground pixel.
        """
        _, _, stats, centroids = cv2.connectedComponentsWithStats(
            mask.astype(np.uint8), connectivity=8)
        # Label 0 is the background, so a mask without any foreground
        # pixel yields a single row of stats.
        if stats.shape[0] <= 1:
            return None
        largest_id = np.argmax(stats[1:, -1]) + 1
        return centroids[largest_id], stats[largest_id, -1]

    def _draw_instances(self, image: np.ndarray, instances: InstanceData,
                        classes: Optional[List[str]],
                        palette: Optional[List[tuple]]) -> np.ndarray:
        """Draw instances of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            instances (:obj:`InstanceData`): Data structure for
                instance-level annotations or predictions.
            classes (List[str], optional): Category information.
            palette (List[tuple], optional): Palette information
                corresponding to the category.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        self.set_image(image)

        if 'bboxes' in instances:
            bboxes = instances.bboxes
            # Plain ints so that the fallback label text reads "class 3"
            # rather than "class tensor(3)".
            labels = [int(label) for label in instances.labels]
            max_label = max(labels, default=0)
            text_palette = get_palette(self.text_color, max_label + 1)
            text_colors = [text_palette[label] for label in labels]

            bbox_color = palette if self.bbox_color is None \
                else self.bbox_color
            bbox_palette = get_palette(bbox_color, max_label + 1)
            colors = [bbox_palette[label] for label in labels]
            self.draw_bboxes(
                bboxes,
                edge_colors=colors,
                alpha=self.alpha,
                line_widths=self.line_width)

            positions = bboxes[:, :2] + self.line_width
            areas = (bboxes[:, 3] - bboxes[:, 1]) * (
                bboxes[:, 2] - bboxes[:, 0])
            scales = _get_adaptive_scales(areas)

            for i, (pos, label) in enumerate(zip(positions, labels)):
                label_text = self._get_label_text(label, classes)
                if 'scores' in instances:
                    score = round(float(instances.scores[i]) * 100, 1)
                    label_text += f': {score}'

                self.draw_texts(
                    label_text,
                    pos,
                    colors=text_colors[i],
                    font_sizes=int(13 * scales[i]),
                    bboxes=self._text_box_props())

        if 'masks' in instances:
            labels = [int(label) for label in instances.labels]
            masks = instances.masks
            if isinstance(masks, torch.Tensor):
                masks = masks.numpy()
            elif isinstance(masks, (PolygonMasks, BitmapMasks)):
                masks = masks.to_ndarray()

            masks = masks.astype(bool)

            max_label = max(labels, default=0)
            mask_color = palette if self.mask_color is None \
                else self.mask_color
            mask_palette = get_palette(mask_color, max_label + 1)
            colors = [jitter_color(mask_palette[label]) for label in labels]
            text_palette = get_palette(self.text_color, max_label + 1)
            text_colors = [text_palette[label] for label in labels]

            polygons = []
            for mask in masks:
                contours, _ = bitmap_to_polygon(mask)
                polygons.extend(contours)
            self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
            self.draw_binary_masks(masks, colors=colors, alphas=self.alpha)

            if len(labels) > 0 and \
                    ('bboxes' not in instances or
                     instances.bboxes.sum() == 0):
                # instances.bboxes.sum()==0 represent dummy bboxes.
                # A typical example of SOLO does not exist bbox branch.
                # Track the original instance index of every labelled mask
                # so that empty masks (which are skipped) do not shift the
                # label / score / color of the following instances.
                indices, positions, areas = [], [], []
                for i, mask in enumerate(masks):
                    component = self._largest_component(mask)
                    if component is None:
                        continue
                    indices.append(i)
                    positions.append(component[0])
                    areas.append(component[1])

                # np.stack raises on an empty sequence, so only draw labels
                # when at least one mask has foreground pixels.
                if indices:
                    scales = _get_adaptive_scales(np.stack(areas, axis=0))
                    for i, pos, scale in zip(indices, positions, scales):
                        label_text = self._get_label_text(labels[i], classes)
                        if 'scores' in instances:
                            score = round(
                                float(instances.scores[i]) * 100, 1)
                            label_text += f': {score}'

                        self.draw_texts(
                            label_text,
                            pos,
                            colors=text_colors[i],
                            font_sizes=int(13 * scale),
                            horizontal_alignments='center',
                            bboxes=self._text_box_props())
        return self.get_image()

    def _draw_panoptic_seg(self, image: np.ndarray,
                           panoptic_seg: PixelData,
                           classes: Optional[List[str]]) -> np.ndarray:
        """Draw panoptic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            panoptic_seg (:obj:`PixelData`): Data structure for
                pixel-level annotations or predictions. ``sem_seg`` may be
                either a numpy array or a torch tensor.
            classes (List[str]): Category information. Must not be None;
                ``len(classes)`` is used as the VOID label id.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        # TODO: Is there a way to bypass?
        num_classes = len(classes)

        seg_map = panoptic_seg.sem_seg[0]
        # Work on numpy throughout: the GT path may pass a torch tensor,
        # which should not be compared against the negatively-strided
        # numpy array of ids below.
        if isinstance(seg_map, torch.Tensor):
            seg_map = seg_map.detach().cpu().numpy()

        ids = np.unique(seg_map)[::-1]
        ids = ids[ids != num_classes]  # drop the VOID label
        labels = (ids % INSTANCE_OFFSET).astype(np.int64)
        segms = (seg_map[None] == ids[:, None, None])

        self.set_image(image)
        if len(ids) == 0:
            # Everything is VOID: nothing to draw.
            return self.get_image()

        max_label = int(labels.max())
        mask_palette = get_palette(self.mask_color, max_label + 1)
        colors = [mask_palette[label] for label in labels]

        # draw segm
        polygons = []
        for mask in segms:
            contours, _ = bitmap_to_polygon(mask)
            polygons.extend(contours)
        self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
        self.draw_binary_masks(segms, colors=colors, alphas=self.alpha)

        # draw label
        indices, positions, areas = [], [], []
        for i, mask in enumerate(segms):
            component = self._largest_component(mask)
            if component is None:
                continue
            indices.append(i)
            positions.append(component[0])
            areas.append(component[1])
        if not indices:
            return self.get_image()
        scales = _get_adaptive_scales(np.stack(areas, axis=0))

        text_palette = get_palette(self.text_color, max_label + 1)
        text_colors = [text_palette[label] for label in labels]

        for i, pos, scale in zip(indices, positions, scales):
            self.draw_texts(
                classes[labels[i]],
                pos,
                colors=text_colors[i],
                font_sizes=int(13 * scale),
                bboxes=self._text_box_props(),
                horizontal_alignments='center')
        return self.get_image()

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional['DetDataSample'] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            # TODO: Supported in mmengine's Visualizer.
            out_file: Optional[str] = None,
            pred_score_thr: float = 0.3,
            step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If GT and prediction are plotted at the same time, they are
          displayed in a stitched image where the left image is the
          ground truth and the right image is the prediction.
        - If ``show`` is True, the drawn image is displayed in a local
          window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file`` instead of being sent to the storage
          backends. It is usually used when the display is not available.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`DetDataSample`, optional): A data
                sample that contain annotations and predictions.
                Defaults to None.
            draw_gt (bool): Whether to draw GT DetDataSample.
                Defaults to True.
            draw_pred (bool): Whether to draw Prediction DetDataSample.
                Defaults to True.
            show (bool): Whether to display the drawn image.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_file (str): Path to output file. Defaults to None.
            pred_score_thr (float): The threshold to visualize the bboxes
                and masks. Defaults to 0.3.
            step (int): Global step value to record. Defaults to 0.
        """
        image = image.clip(0, 255).astype(np.uint8)
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        gt_img_data = None
        pred_img_data = None

        if data_sample is not None:
            data_sample = data_sample.cpu()

        if draw_gt and data_sample is not None:
            gt_img_data = image
            if 'gt_instances' in data_sample:
                gt_img_data = self._draw_instances(image,
                                                   data_sample.gt_instances,
                                                   classes, palette)

            if 'gt_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                gt_img_data = self._draw_panoptic_seg(
                    gt_img_data, data_sample.gt_panoptic_seg, classes)

        if draw_pred and data_sample is not None:
            pred_img_data = image
            if 'pred_instances' in data_sample:
                pred_instances = data_sample.pred_instances
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr]
                pred_img_data = self._draw_instances(image, pred_instances,
                                                     classes, palette)
            if 'pred_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                pred_img_data = self._draw_panoptic_seg(
                    pred_img_data, data_sample.pred_panoptic_seg.numpy(),
                    classes)

        if gt_img_data is not None and pred_img_data is not None:
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            # Display the original image directly if nothing is drawn.
            drawn_img = image

        # It is convenient for users to obtain the drawn image.
        # For example, the user wants to obtain the drawn image and
        # save it as a video during video inference.
        self.set_image(drawn_img)

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # The drawn image is RGB; mmcv.imwrite expects BGR.
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step)
|