local_visualizer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Optional, Tuple, Union
  3. import cv2
  4. import mmcv
  5. import numpy as np
  6. import torch
  7. from mmengine.dist import master_only
  8. from mmengine.structures import InstanceData, PixelData
  9. from mmengine.visualization import Visualizer
  10. from ..evaluation import INSTANCE_OFFSET
  11. from ..registry import VISUALIZERS
  12. from ..structures import DetDataSample
  13. from ..structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon
  14. from .palette import _get_adaptive_scales, get_palette, jitter_color
  15. @VISUALIZERS.register_module()
  16. class DetLocalVisualizer(Visualizer):
  17. """MMDetection Local Visualizer.
  18. Args:
  19. name (str): Name of the instance. Defaults to 'visualizer'.
  20. image (np.ndarray, optional): the origin image to draw. The format
  21. should be RGB. Defaults to None.
  22. vis_backends (list, optional): Visual backend config list.
  23. Defaults to None.
  24. save_dir (str, optional): Save file dir for all storage backends.
  25. If it is None, the backend storage will not save any data.
  26. bbox_color (str, tuple(int), optional): Color of bbox lines.
  27. The tuple of color should be in BGR order. Defaults to None.
  28. text_color (str, tuple(int), optional): Color of texts.
  29. The tuple of color should be in BGR order.
  30. Defaults to (200, 200, 200).
  31. mask_color (str, tuple(int), optional): Color of masks.
  32. The tuple of color should be in BGR order.
  33. Defaults to None.
  34. line_width (int, float): The linewidth of lines.
  35. Defaults to 3.
  36. alpha (int, float): The transparency of bboxes or mask.
  37. Defaults to 0.8.
  38. Examples:
  39. >>> import numpy as np
  40. >>> import torch
  41. >>> from mmengine.structures import InstanceData
  42. >>> from mmdet.structures import DetDataSample
  43. >>> from mmdet.visualization import DetLocalVisualizer
  44. >>> det_local_visualizer = DetLocalVisualizer()
  45. >>> image = np.random.randint(0, 256,
  46. ... size=(10, 12, 3)).astype('uint8')
  47. >>> gt_instances = InstanceData()
  48. >>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]])
  49. >>> gt_instances.labels = torch.randint(0, 2, (1,))
  50. >>> gt_det_data_sample = DetDataSample()
  51. >>> gt_det_data_sample.gt_instances = gt_instances
  52. >>> det_local_visualizer.add_datasample('image', image,
  53. ... gt_det_data_sample)
  54. >>> det_local_visualizer.add_datasample(
  55. ... 'image', image, gt_det_data_sample,
  56. ... out_file='out_file.jpg')
  57. >>> det_local_visualizer.add_datasample(
  58. ... 'image', image, gt_det_data_sample,
  59. ... show=True)
  60. >>> pred_instances = InstanceData()
  61. >>> pred_instances.bboxes = torch.Tensor([[2, 4, 4, 8]])
  62. >>> pred_instances.labels = torch.randint(0, 2, (1,))
  63. >>> pred_det_data_sample = DetDataSample()
  64. >>> pred_det_data_sample.pred_instances = pred_instances
  65. >>> det_local_visualizer.add_datasample('image', image,
  66. ... gt_det_data_sample,
  67. ... pred_det_data_sample)
  68. """
  69. def __init__(self,
  70. name: str = 'visualizer',
  71. image: Optional[np.ndarray] = None,
  72. vis_backends: Optional[Dict] = None,
  73. save_dir: Optional[str] = None,
  74. bbox_color: Optional[Union[str, Tuple[int]]] = None,
  75. text_color: Optional[Union[str,
  76. Tuple[int]]] = (200, 200, 200),
  77. mask_color: Optional[Union[str, Tuple[int]]] = None,
  78. line_width: Union[int, float] = 3,
  79. alpha: float = 0.8) -> None:
  80. super().__init__(
  81. name=name,
  82. image=image,
  83. vis_backends=vis_backends,
  84. save_dir=save_dir)
  85. self.bbox_color = bbox_color
  86. self.text_color = text_color
  87. self.mask_color = mask_color
  88. self.line_width = line_width
  89. self.alpha = alpha
  90. # Set default value. When calling
  91. # `DetLocalVisualizer().dataset_meta=xxx`,
  92. # it will override the default value.
  93. self.dataset_meta = {}
  94. def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'],
  95. classes: Optional[List[str]],
  96. palette: Optional[List[tuple]]) -> np.ndarray:
  97. """Draw instances of GT or prediction.
  98. Args:
  99. image (np.ndarray): The image to draw.
  100. instances (:obj:`InstanceData`): Data structure for
  101. instance-level annotations or predictions.
  102. classes (List[str], optional): Category information.
  103. palette (List[tuple], optional): Palette information
  104. corresponding to the category.
  105. Returns:
  106. np.ndarray: the drawn image which channel is RGB.
  107. """
  108. self.set_image(image)
  109. if 'bboxes' in instances:
  110. bboxes = instances.bboxes
  111. labels = instances.labels
  112. max_label = int(max(labels) if len(labels) > 0 else 0)
  113. text_palette = get_palette(self.text_color, max_label + 1)
  114. text_colors = [text_palette[label] for label in labels]
  115. bbox_color = palette if self.bbox_color is None \
  116. else self.bbox_color
  117. bbox_palette = get_palette(bbox_color, max_label + 1)
  118. colors = [bbox_palette[label] for label in labels]
  119. self.draw_bboxes(
  120. bboxes,
  121. edge_colors=colors,
  122. alpha=self.alpha,
  123. line_widths=self.line_width)
  124. positions = bboxes[:, :2] + self.line_width
  125. areas = (bboxes[:, 3] - bboxes[:, 1]) * (
  126. bboxes[:, 2] - bboxes[:, 0])
  127. scales = _get_adaptive_scales(areas)
  128. for i, (pos, label) in enumerate(zip(positions, labels)):
  129. label_text = classes[
  130. label] if classes is not None else f'class {label}'
  131. if 'scores' in instances:
  132. score = round(float(instances.scores[i]) * 100, 1)
  133. label_text += f': {score}'
  134. self.draw_texts(
  135. label_text,
  136. pos,
  137. colors=text_colors[i],
  138. font_sizes=int(13 * scales[i]),
  139. bboxes=[{
  140. 'facecolor': 'black',
  141. 'alpha': 0.8,
  142. 'pad': 0.7,
  143. 'edgecolor': 'none'
  144. }])
  145. if 'masks' in instances:
  146. labels = instances.labels
  147. masks = instances.masks
  148. if isinstance(masks, torch.Tensor):
  149. masks = masks.numpy()
  150. elif isinstance(masks, (PolygonMasks, BitmapMasks)):
  151. masks = masks.to_ndarray()
  152. masks = masks.astype(bool)
  153. max_label = int(max(labels) if len(labels) > 0 else 0)
  154. mask_color = palette if self.mask_color is None \
  155. else self.mask_color
  156. mask_palette = get_palette(mask_color, max_label + 1)
  157. colors = [jitter_color(mask_palette[label]) for label in labels]
  158. text_palette = get_palette(self.text_color, max_label + 1)
  159. text_colors = [text_palette[label] for label in labels]
  160. polygons = []
  161. for i, mask in enumerate(masks):
  162. contours, _ = bitmap_to_polygon(mask)
  163. polygons.extend(contours)
  164. self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
  165. self.draw_binary_masks(masks, colors=colors, alphas=self.alpha)
  166. if len(labels) > 0 and \
  167. ('bboxes' not in instances or
  168. instances.bboxes.sum() == 0):
  169. # instances.bboxes.sum()==0 represent dummy bboxes.
  170. # A typical example of SOLO does not exist bbox branch.
  171. areas = []
  172. positions = []
  173. for mask in masks:
  174. _, _, stats, centroids = cv2.connectedComponentsWithStats(
  175. mask.astype(np.uint8), connectivity=8)
  176. if stats.shape[0] > 1:
  177. largest_id = np.argmax(stats[1:, -1]) + 1
  178. positions.append(centroids[largest_id])
  179. areas.append(stats[largest_id, -1])
  180. areas = np.stack(areas, axis=0)
  181. scales = _get_adaptive_scales(areas)
  182. for i, (pos, label) in enumerate(zip(positions, labels)):
  183. label_text = classes[
  184. label] if classes is not None else f'class {label}'
  185. if 'scores' in instances:
  186. score = round(float(instances.scores[i]) * 100, 1)
  187. label_text += f': {score}'
  188. self.draw_texts(
  189. label_text,
  190. pos,
  191. colors=text_colors[i],
  192. font_sizes=int(13 * scales[i]),
  193. horizontal_alignments='center',
  194. bboxes=[{
  195. 'facecolor': 'black',
  196. 'alpha': 0.8,
  197. 'pad': 0.7,
  198. 'edgecolor': 'none'
  199. }])
  200. return self.get_image()
  201. def _draw_panoptic_seg(self, image: np.ndarray,
  202. panoptic_seg: ['PixelData'],
  203. classes: Optional[List[str]]) -> np.ndarray:
  204. """Draw panoptic seg of GT or prediction.
  205. Args:
  206. image (np.ndarray): The image to draw.
  207. panoptic_seg (:obj:`PixelData`): Data structure for
  208. pixel-level annotations or predictions.
  209. classes (List[str], optional): Category information.
  210. Returns:
  211. np.ndarray: the drawn image which channel is RGB.
  212. """
  213. # TODO: Is there a way to bypass?
  214. num_classes = len(classes)
  215. panoptic_seg = panoptic_seg.sem_seg[0]
  216. ids = np.unique(panoptic_seg)[::-1]
  217. legal_indices = ids != num_classes # for VOID label
  218. ids = ids[legal_indices]
  219. labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
  220. segms = (panoptic_seg[None] == ids[:, None, None])
  221. max_label = int(max(labels) if len(labels) > 0 else 0)
  222. mask_palette = get_palette(self.mask_color, max_label + 1)
  223. colors = [mask_palette[label] for label in labels]
  224. self.set_image(image)
  225. # draw segm
  226. polygons = []
  227. for i, mask in enumerate(segms):
  228. contours, _ = bitmap_to_polygon(mask)
  229. polygons.extend(contours)
  230. self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha)
  231. self.draw_binary_masks(segms, colors=colors, alphas=self.alpha)
  232. # draw label
  233. areas = []
  234. positions = []
  235. for mask in segms:
  236. _, _, stats, centroids = cv2.connectedComponentsWithStats(
  237. mask.astype(np.uint8), connectivity=8)
  238. max_id = np.argmax(stats[1:, -1]) + 1
  239. positions.append(centroids[max_id])
  240. areas.append(stats[max_id, -1])
  241. areas = np.stack(areas, axis=0)
  242. scales = _get_adaptive_scales(areas)
  243. text_palette = get_palette(self.text_color, max_label + 1)
  244. text_colors = [text_palette[label] for label in labels]
  245. for i, (pos, label) in enumerate(zip(positions, labels)):
  246. label_text = classes[label]
  247. self.draw_texts(
  248. label_text,
  249. pos,
  250. colors=text_colors[i],
  251. font_sizes=int(13 * scales[i]),
  252. bboxes=[{
  253. 'facecolor': 'black',
  254. 'alpha': 0.8,
  255. 'pad': 0.7,
  256. 'edgecolor': 'none'
  257. }],
  258. horizontal_alignments='center')
  259. return self.get_image()
  260. @master_only
  261. def add_datasample(
  262. self,
  263. name: str,
  264. image: np.ndarray,
  265. data_sample: Optional['DetDataSample'] = None,
  266. draw_gt: bool = True,
  267. draw_pred: bool = True,
  268. show: bool = False,
  269. wait_time: float = 0,
  270. # TODO: Supported in mmengine's Viusalizer.
  271. out_file: Optional[str] = None,
  272. pred_score_thr: float = 0.3,
  273. step: int = 0) -> None:
  274. """Draw datasample and save to all backends.
  275. - If GT and prediction are plotted at the same time, they are
  276. displayed in a stitched image where the left image is the
  277. ground truth and the right image is the prediction.
  278. - If ``show`` is True, all storage backends are ignored, and
  279. the images will be displayed in a local window.
  280. - If ``out_file`` is specified, the drawn image will be
  281. saved to ``out_file``. t is usually used when the display
  282. is not available.
  283. Args:
  284. name (str): The image identifier.
  285. image (np.ndarray): The image to draw.
  286. data_sample (:obj:`DetDataSample`, optional): A data
  287. sample that contain annotations and predictions.
  288. Defaults to None.
  289. draw_gt (bool): Whether to draw GT DetDataSample. Default to True.
  290. draw_pred (bool): Whether to draw Prediction DetDataSample.
  291. Defaults to True.
  292. show (bool): Whether to display the drawn image. Default to False.
  293. wait_time (float): The interval of show (s). Defaults to 0.
  294. out_file (str): Path to output file. Defaults to None.
  295. pred_score_thr (float): The threshold to visualize the bboxes
  296. and masks. Defaults to 0.3.
  297. step (int): Global step value to record. Defaults to 0.
  298. """
  299. image = image.clip(0, 255).astype(np.uint8)
  300. classes = self.dataset_meta.get('classes', None)
  301. palette = self.dataset_meta.get('palette', None)
  302. gt_img_data = None
  303. pred_img_data = None
  304. if data_sample is not None:
  305. data_sample = data_sample.cpu()
  306. if draw_gt and data_sample is not None:
  307. gt_img_data = image
  308. if 'gt_instances' in data_sample:
  309. gt_img_data = self._draw_instances(image,
  310. data_sample.gt_instances,
  311. classes, palette)
  312. if 'gt_panoptic_seg' in data_sample:
  313. assert classes is not None, 'class information is ' \
  314. 'not provided when ' \
  315. 'visualizing panoptic ' \
  316. 'segmentation results.'
  317. gt_img_data = self._draw_panoptic_seg(
  318. gt_img_data, data_sample.gt_panoptic_seg, classes)
  319. if draw_pred and data_sample is not None:
  320. pred_img_data = image
  321. if 'pred_instances' in data_sample:
  322. pred_instances = data_sample.pred_instances
  323. pred_instances = pred_instances[
  324. pred_instances.scores > pred_score_thr]
  325. pred_img_data = self._draw_instances(image, pred_instances,
  326. classes, palette)
  327. if 'pred_panoptic_seg' in data_sample:
  328. assert classes is not None, 'class information is ' \
  329. 'not provided when ' \
  330. 'visualizing panoptic ' \
  331. 'segmentation results.'
  332. pred_img_data = self._draw_panoptic_seg(
  333. pred_img_data, data_sample.pred_panoptic_seg.numpy(),
  334. classes)
  335. if gt_img_data is not None and pred_img_data is not None:
  336. drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
  337. elif gt_img_data is not None:
  338. drawn_img = gt_img_data
  339. elif pred_img_data is not None:
  340. drawn_img = pred_img_data
  341. else:
  342. # Display the original image directly if nothing is drawn.
  343. drawn_img = image
  344. # It is convenient for users to obtain the drawn image.
  345. # For example, the user wants to obtain the drawn image and
  346. # save it as a video during video inference.
  347. self.set_image(drawn_img)
  348. if show:
  349. self.show(drawn_img, win_name=name, wait_time=wait_time)
  350. if out_file is not None:
  351. mmcv.imwrite(drawn_img[..., ::-1], out_file)
  352. else:
  353. self.add_image(name, drawn_img, step)