det_inferencer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import os.path as osp
  4. import warnings
  5. from typing import Dict, Iterable, List, Optional, Sequence, Union
  6. import mmcv
  7. import mmengine
  8. import numpy as np
  9. import torch.nn as nn
  10. from mmengine.dataset import Compose
  11. from mmengine.fileio import (get_file_backend, isdir, join_path,
  12. list_dir_or_file)
  13. from mmengine.infer.infer import BaseInferencer, ModelType
  14. from mmengine.model.utils import revert_sync_batchnorm
  15. from mmengine.registry import init_default_scope
  16. from mmengine.runner.checkpoint import _load_checkpoint_to_model
  17. from mmengine.visualization import Visualizer
  18. from rich.progress import track
  19. from mmdet.evaluation import INSTANCE_OFFSET
  20. from mmdet.registry import DATASETS
  21. from mmdet.structures import DetDataSample
  22. from mmdet.structures.mask import encode_mask_results, mask2bbox
  23. from mmdet.utils import ConfigType
  24. from ..evaluation import get_classes
  25. try:
  26. from panopticapi.evaluation import VOID
  27. from panopticapi.utils import id2rgb
  28. except ImportError:
  29. id2rgb = None
  30. VOID = None
  31. InputType = Union[str, np.ndarray]
  32. InputsType = Union[InputType, Sequence[InputType]]
  33. PredType = List[DetDataSample]
  34. ImgType = Union[np.ndarray, Sequence[np.ndarray]]
  35. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
  36. '.tiff', '.webp')
  37. class DetInferencer(BaseInferencer):
  38. """Object Detection Inferencer.
  39. Args:
  40. model (str, optional): Path to the config file or the model name
  41. defined in metafile. For example, it could be
  42. "rtmdet-s" or 'rtmdet_s_8xb32-300e_coco' or
  43. "configs/rtmdet/rtmdet_s_8xb32-300e_coco.py".
  44. If model is not specified, user must provide the
  45. `weights` saved by MMEngine which contains the config string.
  46. Defaults to None.
  47. weights (str, optional): Path to the checkpoint. If it is not specified
  48. and model is a model name of metafile, the weights will be loaded
  49. from metafile. Defaults to None.
  50. device (str, optional): Device to run inference. If None, the available
  51. device will be automatically used. Defaults to None.
  52. scope (str, optional): The scope of the model. Defaults to mmdet.
  53. palette (str): Color palette used for visualization. The order of
  54. priority is palette -> config -> checkpoint. Defaults to 'none'.
  55. """
  56. preprocess_kwargs: set = set()
  57. forward_kwargs: set = set()
  58. visualize_kwargs: set = {
  59. 'return_vis',
  60. 'show',
  61. 'wait_time',
  62. 'draw_pred',
  63. 'pred_score_thr',
  64. 'img_out_dir',
  65. 'no_save_vis',
  66. }
  67. postprocess_kwargs: set = {
  68. 'print_result',
  69. 'pred_out_dir',
  70. 'return_datasample',
  71. 'no_save_pred',
  72. }
  73. def __init__(self,
  74. model: Optional[Union[ModelType, str]] = None,
  75. weights: Optional[str] = None,
  76. device: Optional[str] = None,
  77. scope: Optional[str] = 'mmdet',
  78. palette: str = 'none') -> None:
  79. # A global counter tracking the number of images processed, for
  80. # naming of the output images
  81. self.num_visualized_imgs = 0
  82. self.num_predicted_imgs = 0
  83. self.palette = palette
  84. init_default_scope(scope)
  85. super().__init__(
  86. model=model, weights=weights, device=device, scope=scope)
  87. self.model = revert_sync_batchnorm(self.model)
  88. def _load_weights_to_model(self, model: nn.Module,
  89. checkpoint: Optional[dict],
  90. cfg: Optional[ConfigType]) -> None:
  91. """Loading model weights and meta information from cfg and checkpoint.
  92. Args:
  93. model (nn.Module): Model to load weights and meta information.
  94. checkpoint (dict, optional): The loaded checkpoint.
  95. cfg (Config or ConfigDict, optional): The loaded config.
  96. """
  97. if checkpoint is not None:
  98. _load_checkpoint_to_model(model, checkpoint)
  99. checkpoint_meta = checkpoint.get('meta', {})
  100. # save the dataset_meta in the model for convenience
  101. if 'dataset_meta' in checkpoint_meta:
  102. # mmdet 3.x, all keys should be lowercase
  103. model.dataset_meta = {
  104. k.lower(): v
  105. for k, v in checkpoint_meta['dataset_meta'].items()
  106. }
  107. elif 'CLASSES' in checkpoint_meta:
  108. # < mmdet 3.x
  109. classes = checkpoint_meta['CLASSES']
  110. model.dataset_meta = {'classes': classes}
  111. else:
  112. warnings.warn(
  113. 'dataset_meta or class names are not saved in the '
  114. 'checkpoint\'s meta data, use COCO classes by default.')
  115. model.dataset_meta = {'classes': get_classes('coco')}
  116. else:
  117. warnings.warn('Checkpoint is not loaded, and the inference '
  118. 'result is calculated by the randomly initialized '
  119. 'model!')
  120. warnings.warn('weights is None, use COCO classes by default.')
  121. model.dataset_meta = {'classes': get_classes('coco')}
  122. # Priority: args.palette -> config -> checkpoint
  123. if self.palette != 'none':
  124. model.dataset_meta['palette'] = self.palette
  125. else:
  126. test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
  127. # lazy init. We only need the metainfo.
  128. test_dataset_cfg['lazy_init'] = True
  129. metainfo = DATASETS.build(test_dataset_cfg).metainfo
  130. cfg_palette = metainfo.get('palette', None)
  131. if cfg_palette is not None:
  132. model.dataset_meta['palette'] = cfg_palette
  133. else:
  134. if 'palette' not in model.dataset_meta:
  135. warnings.warn(
  136. 'palette does not exist, random is used by default. '
  137. 'You can also set the palette to customize.')
  138. model.dataset_meta['palette'] = 'random'
  139. def _init_pipeline(self, cfg: ConfigType) -> Compose:
  140. """Initialize the test pipeline."""
  141. pipeline_cfg = cfg.test_dataloader.dataset.pipeline
  142. # For inference, the key of ``img_id`` is not used.
  143. if 'meta_keys' in pipeline_cfg[-1]:
  144. pipeline_cfg[-1]['meta_keys'] = tuple(
  145. meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
  146. if meta_key != 'img_id')
  147. load_img_idx = self._get_transform_idx(pipeline_cfg,
  148. 'LoadImageFromFile')
  149. if load_img_idx == -1:
  150. raise ValueError(
  151. 'LoadImageFromFile is not found in the test pipeline')
  152. pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
  153. return Compose(pipeline_cfg)
  154. def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
  155. """Returns the index of the transform in a pipeline.
  156. If the transform is not found, returns -1.
  157. """
  158. for i, transform in enumerate(pipeline_cfg):
  159. if transform['type'] == name:
  160. return i
  161. return -1
  162. def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
  163. """Initialize visualizers.
  164. Args:
  165. cfg (ConfigType): Config containing the visualizer information.
  166. Returns:
  167. Visualizer or None: Visualizer initialized with config.
  168. """
  169. visualizer = super()._init_visualizer(cfg)
  170. visualizer.dataset_meta = self.model.dataset_meta
  171. return visualizer
  172. def _inputs_to_list(self, inputs: InputsType) -> list:
  173. """Preprocess the inputs to a list.
  174. Preprocess inputs to a list according to its type:
  175. - list or tuple: return inputs
  176. - str:
  177. - Directory path: return all files in the directory
  178. - other cases: return a list containing the string. The string
  179. could be a path to file, a url or other types of string according
  180. to the task.
  181. Args:
  182. inputs (InputsType): Inputs for the inferencer.
  183. Returns:
  184. list: List of input for the :meth:`preprocess`.
  185. """
  186. if isinstance(inputs, str):
  187. backend = get_file_backend(inputs)
  188. if hasattr(backend, 'isdir') and isdir(inputs):
  189. # Backends like HttpsBackend do not implement `isdir`, so only
  190. # those backends that implement `isdir` could accept the inputs
  191. # as a directory
  192. filename_list = list_dir_or_file(
  193. inputs, list_dir=False, suffix=IMG_EXTENSIONS)
  194. inputs = [
  195. join_path(inputs, filename) for filename in filename_list
  196. ]
  197. if not isinstance(inputs, (list, tuple)):
  198. inputs = [inputs]
  199. return list(inputs)
  200. def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
  201. """Process the inputs into a model-feedable format.
  202. Customize your preprocess by overriding this method. Preprocess should
  203. return an iterable object, of which each item will be used as the
  204. input of ``model.test_step``.
  205. ``BaseInferencer.preprocess`` will return an iterable chunked data,
  206. which will be used in __call__ like this:
  207. .. code-block:: python
  208. def __call__(self, inputs, batch_size=1, **kwargs):
  209. chunked_data = self.preprocess(inputs, batch_size, **kwargs)
  210. for batch in chunked_data:
  211. preds = self.forward(batch, **kwargs)
  212. Args:
  213. inputs (InputsType): Inputs given by user.
  214. batch_size (int): batch size. Defaults to 1.
  215. Yields:
  216. Any: Data processed by the ``pipeline`` and ``collate_fn``.
  217. """
  218. chunked_data = self._get_chunk_data(inputs, batch_size)
  219. yield from map(self.collate_fn, chunked_data)
  220. def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
  221. """Get batch data from inputs.
  222. Args:
  223. inputs (Iterable): An iterable dataset.
  224. chunk_size (int): Equivalent to batch size.
  225. Yields:
  226. list: batch data.
  227. """
  228. inputs_iter = iter(inputs)
  229. while True:
  230. try:
  231. chunk_data = []
  232. for _ in range(chunk_size):
  233. inputs_ = next(inputs_iter)
  234. chunk_data.append((inputs_, self.pipeline(inputs_)))
  235. yield chunk_data
  236. except StopIteration:
  237. if chunk_data:
  238. yield chunk_data
  239. break
  240. # TODO: Video and Webcam are currently not supported and
  241. # may consume too much memory if your input folder has a lot of images.
  242. # We will be optimized later.
  243. def __call__(self,
  244. inputs: InputsType,
  245. batch_size: int = 1,
  246. return_vis: bool = False,
  247. show: bool = False,
  248. wait_time: int = 0,
  249. no_save_vis: bool = False,
  250. draw_pred: bool = True,
  251. pred_score_thr: float = 0.3,
  252. return_datasample: bool = False,
  253. print_result: bool = False,
  254. no_save_pred: bool = True,
  255. out_dir: str = '',
  256. **kwargs) -> dict:
  257. """Call the inferencer.
  258. Args:
  259. inputs (InputsType): Inputs for the inferencer.
  260. batch_size (int): Inference batch size. Defaults to 1.
  261. show (bool): Whether to display the visualization results in a
  262. popup window. Defaults to False.
  263. wait_time (float): The interval of show (s). Defaults to 0.
  264. no_save_vis (bool): Whether to force not to save prediction
  265. vis results. Defaults to False.
  266. draw_pred (bool): Whether to draw predicted bounding boxes.
  267. Defaults to True.
  268. pred_score_thr (float): Minimum score of bboxes to draw.
  269. Defaults to 0.3.
  270. return_datasample (bool): Whether to return results as
  271. :obj:`DetDataSample`. Defaults to False.
  272. print_result (bool): Whether to print the inference result w/o
  273. visualization to the console. Defaults to False.
  274. no_save_pred (bool): Whether to force not to save prediction
  275. results. Defaults to True.
  276. out_file: Dir to save the inference results or
  277. visualization. If left as empty, no file will be saved.
  278. Defaults to ''.
  279. **kwargs: Other keyword arguments passed to :meth:`preprocess`,
  280. :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
  281. Each key in kwargs should be in the corresponding set of
  282. ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
  283. and ``postprocess_kwargs``.
  284. Returns:
  285. dict: Inference and visualization results.
  286. """
  287. (
  288. preprocess_kwargs,
  289. forward_kwargs,
  290. visualize_kwargs,
  291. postprocess_kwargs,
  292. ) = self._dispatch_kwargs(**kwargs)
  293. ori_inputs = self._inputs_to_list(inputs)
  294. inputs = self.preprocess(
  295. ori_inputs, batch_size=batch_size, **preprocess_kwargs)
  296. results_dict = {'predictions': [], 'visualization': []}
  297. for ori_inputs, data in track(inputs, description='Inference'):
  298. preds = self.forward(data, **forward_kwargs)
  299. visualization = self.visualize(
  300. ori_inputs,
  301. preds,
  302. return_vis=return_vis,
  303. show=show,
  304. wait_time=wait_time,
  305. draw_pred=draw_pred,
  306. pred_score_thr=pred_score_thr,
  307. no_save_vis=no_save_vis,
  308. img_out_dir=out_dir,
  309. **visualize_kwargs)
  310. results = self.postprocess(
  311. preds,
  312. visualization,
  313. return_datasample=return_datasample,
  314. print_result=print_result,
  315. no_save_pred=no_save_pred,
  316. pred_out_dir=out_dir,
  317. **postprocess_kwargs)
  318. results_dict['predictions'].extend(results['predictions'])
  319. if results['visualization'] is not None:
  320. results_dict['visualization'].extend(results['visualization'])
  321. return results_dict
  322. def visualize(self,
  323. inputs: InputsType,
  324. preds: PredType,
  325. return_vis: bool = False,
  326. show: bool = False,
  327. wait_time: int = 0,
  328. draw_pred: bool = True,
  329. pred_score_thr: float = 0.3,
  330. no_save_vis: bool = False,
  331. img_out_dir: str = '',
  332. **kwargs) -> Union[List[np.ndarray], None]:
  333. """Visualize predictions.
  334. Args:
  335. inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
  336. preds (List[:obj:`DetDataSample`]): Predictions of the model.
  337. return_vis (bool): Whether to return the visualization result.
  338. Defaults to False.
  339. show (bool): Whether to display the image in a popup window.
  340. Defaults to False.
  341. wait_time (float): The interval of show (s). Defaults to 0.
  342. draw_pred (bool): Whether to draw predicted bounding boxes.
  343. Defaults to True.
  344. pred_score_thr (float): Minimum score of bboxes to draw.
  345. Defaults to 0.3.
  346. no_save_vis (bool): Whether to force not to save prediction
  347. vis results. Defaults to False.
  348. img_out_dir (str): Output directory of visualization results.
  349. If left as empty, no file will be saved. Defaults to ''.
  350. Returns:
  351. List[np.ndarray] or None: Returns visualization results only if
  352. applicable.
  353. """
  354. if no_save_vis is True:
  355. img_out_dir = ''
  356. if not show and img_out_dir == '' and not return_vis:
  357. return None
  358. if self.visualizer is None:
  359. raise ValueError('Visualization needs the "visualizer" term'
  360. 'defined in the config, but got None.')
  361. results = []
  362. for single_input, pred in zip(inputs, preds):
  363. if isinstance(single_input, str):
  364. img_bytes = mmengine.fileio.get(single_input)
  365. img = mmcv.imfrombytes(img_bytes)
  366. img = img[:, :, ::-1]
  367. img_name = osp.basename(single_input)
  368. elif isinstance(single_input, np.ndarray):
  369. img = single_input.copy()
  370. img_num = str(self.num_visualized_imgs).zfill(8)
  371. img_name = f'{img_num}.jpg'
  372. else:
  373. raise ValueError('Unsupported input type: '
  374. f'{type(single_input)}')
  375. out_file = osp.join(img_out_dir, 'vis',
  376. img_name) if img_out_dir != '' else None
  377. self.visualizer.add_datasample(
  378. img_name,
  379. img,
  380. pred,
  381. show=show,
  382. wait_time=wait_time,
  383. draw_gt=False,
  384. draw_pred=draw_pred,
  385. pred_score_thr=pred_score_thr,
  386. out_file=out_file,
  387. )
  388. results.append(self.visualizer.get_image())
  389. self.num_visualized_imgs += 1
  390. return results
  391. def postprocess(
  392. self,
  393. preds: PredType,
  394. visualization: Optional[List[np.ndarray]] = None,
  395. return_datasample: bool = False,
  396. print_result: bool = False,
  397. no_save_pred: bool = False,
  398. pred_out_dir: str = '',
  399. **kwargs,
  400. ) -> Dict:
  401. """Process the predictions and visualization results from ``forward``
  402. and ``visualize``.
  403. This method should be responsible for the following tasks:
  404. 1. Convert datasamples into a json-serializable dict if needed.
  405. 2. Pack the predictions and visualization results and return them.
  406. 3. Dump or log the predictions.
  407. Args:
  408. preds (List[:obj:`DetDataSample`]): Predictions of the model.
  409. visualization (Optional[np.ndarray]): Visualized predictions.
  410. return_datasample (bool): Whether to use Datasample to store
  411. inference results. If False, dict will be used.
  412. print_result (bool): Whether to print the inference result w/o
  413. visualization to the console. Defaults to False.
  414. no_save_pred (bool): Whether to force not to save prediction
  415. results. Defaults to False.
  416. pred_out_dir: Dir to save the inference results w/o
  417. visualization. If left as empty, no file will be saved.
  418. Defaults to ''.
  419. Returns:
  420. dict: Inference and visualization results with key ``predictions``
  421. and ``visualization``.
  422. - ``visualization`` (Any): Returned by :meth:`visualize`.
  423. - ``predictions`` (dict or DataSample): Returned by
  424. :meth:`forward` and processed in :meth:`postprocess`.
  425. If ``return_datasample=False``, it usually should be a
  426. json-serializable dict containing only basic data elements such
  427. as strings and numbers.
  428. """
  429. if no_save_pred is True:
  430. pred_out_dir = ''
  431. result_dict = {}
  432. results = preds
  433. if not return_datasample:
  434. results = []
  435. for pred in preds:
  436. result = self.pred2dict(pred, pred_out_dir)
  437. results.append(result)
  438. elif pred_out_dir != '':
  439. warnings.warn('Currently does not support saving datasample '
  440. 'when return_datasample is set to True. '
  441. 'Prediction results are not saved!')
  442. # Add img to the results after printing and dumping
  443. result_dict['predictions'] = results
  444. if print_result:
  445. print(result_dict)
  446. result_dict['visualization'] = visualization
  447. return result_dict
  448. # TODO: The data format and fields saved in json need further discussion.
  449. # Maybe should include model name, timestamp, filename, image info etc.
  450. def pred2dict(self,
  451. data_sample: DetDataSample,
  452. pred_out_dir: str = '') -> Dict:
  453. """Extract elements necessary to represent a prediction into a
  454. dictionary.
  455. It's better to contain only basic data elements such as strings and
  456. numbers in order to guarantee it's json-serializable.
  457. Args:
  458. data_sample (:obj:`DetDataSample`): Predictions of the model.
  459. pred_out_dir: Dir to save the inference results w/o
  460. visualization. If left as empty, no file will be saved.
  461. Defaults to ''.
  462. Returns:
  463. dict: Prediction results.
  464. """
  465. is_save_pred = True
  466. if pred_out_dir == '':
  467. is_save_pred = False
  468. if is_save_pred and 'img_path' in data_sample:
  469. img_path = osp.basename(data_sample.img_path)
  470. img_path = osp.splitext(img_path)[0]
  471. out_img_path = osp.join(pred_out_dir, 'preds',
  472. img_path + '_panoptic_seg.png')
  473. out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
  474. elif is_save_pred:
  475. out_img_path = osp.join(
  476. pred_out_dir, 'preds',
  477. f'{self.num_predicted_imgs}_panoptic_seg.png')
  478. out_json_path = osp.join(pred_out_dir, 'preds',
  479. f'{self.num_predicted_imgs}.json')
  480. self.num_predicted_imgs += 1
  481. result = {}
  482. if 'pred_instances' in data_sample:
  483. masks = data_sample.pred_instances.get('masks')
  484. pred_instances = data_sample.pred_instances.numpy()
  485. result = {
  486. 'bboxes': pred_instances.bboxes.tolist(),
  487. 'labels': pred_instances.labels.tolist(),
  488. 'scores': pred_instances.scores.tolist()
  489. }
  490. if masks is not None:
  491. if pred_instances.bboxes.sum() == 0:
  492. # Fake bbox, such as the SOLO.
  493. bboxes = mask2bbox(masks.cpu()).numpy().tolist()
  494. result['bboxes'] = bboxes
  495. encode_masks = encode_mask_results(pred_instances.masks)
  496. for encode_mask in encode_masks:
  497. if isinstance(encode_mask['counts'], bytes):
  498. encode_mask['counts'] = encode_mask['counts'].decode()
  499. result['masks'] = encode_masks
  500. if 'pred_panoptic_seg' in data_sample:
  501. if VOID is None:
  502. raise RuntimeError(
  503. 'panopticapi is not installed, please install it by: '
  504. 'pip install git+https://github.com/cocodataset/'
  505. 'panopticapi.git.')
  506. pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
  507. pan[pan % INSTANCE_OFFSET == len(
  508. self.model.dataset_meta['classes'])] = VOID
  509. pan = id2rgb(pan).astype(np.uint8)
  510. if is_save_pred:
  511. mmcv.imwrite(pan[:, :, ::-1], out_img_path)
  512. result['panoptic_seg_path'] = out_img_path
  513. else:
  514. result['panoptic_seg'] = pan
  515. if is_save_pred:
  516. mmengine.dump(result, out_json_path)
  517. return result