# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
import torch.nn as nn
from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
                             list_dir_or_file)
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from mmengine.visualization import Visualizer
from rich.progress import track

from mmdet.evaluation import INSTANCE_OFFSET
from mmdet.registry import DATASETS
from mmdet.structures import DetDataSample
from mmdet.structures.mask import encode_mask_results, mask2bbox
from mmdet.utils import ConfigType
from ..evaluation import get_classes

try:
    from panopticapi.evaluation import VOID
    from panopticapi.utils import id2rgb
except ImportError:
    id2rgb = None
    VOID = None

InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = List[DetDataSample]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


class DetInferencer(BaseInferencer):
    """Object Detection Inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            'rtmdet-s', 'rtmdet_s_8xb32-300e_coco' or
            'configs/rtmdet/rtmdet_s_8xb32-300e_coco.py'.
            If model is not specified, the user must provide the
            ``weights`` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and ``model`` is a model name in the metafile, the
            weights will be loaded from the metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to 'mmdet'.
        palette (str): Color palette used for visualization. The order of
            priority is palette -> config -> checkpoint. Defaults to 'none'.
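
    Examples:
        A minimal usage sketch (it assumes the ``rtmdet_tiny_8xb32-300e_coco``
        metafile name is available in your mmdet installation and that a
        local ``demo.jpg`` exists)::

            >>> from mmdet.apis import DetInferencer
            >>> inferencer = DetInferencer('rtmdet_tiny_8xb32-300e_coco')
            >>> results = inferencer('demo.jpg', out_dir='outputs/')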
- """
- preprocess_kwargs: set = set()
- forward_kwargs: set = set()
- visualize_kwargs: set = {
- 'return_vis',
- 'show',
- 'wait_time',
- 'draw_pred',
- 'pred_score_thr',
- 'img_out_dir',
- 'no_save_vis',
- }
- postprocess_kwargs: set = {
- 'print_result',
- 'pred_out_dir',
- 'return_datasample',
- 'no_save_pred',
- }

    def __init__(self,
                 model: Optional[Union[ModelType, str]] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmdet',
                 palette: str = 'none') -> None:
        # A global counter tracking the number of images processed, for
        # naming of the output images.
        self.num_visualized_imgs = 0
        self.num_predicted_imgs = 0
        self.palette = palette
        init_default_scope(scope)
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        self.model = revert_sync_batchnorm(self.model)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """
        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # Save the dataset_meta in the model for convenience.
            if 'dataset_meta' in checkpoint_meta:
                # mmdet 3.x: all keys should be lowercase.
                model.dataset_meta = {
                    k.lower(): v
                    for k, v in checkpoint_meta['dataset_meta'].items()
                }
            elif 'CLASSES' in checkpoint_meta:
                # < mmdet 3.x
                classes = checkpoint_meta['CLASSES']
                model.dataset_meta = {'classes': classes}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    "checkpoint's meta data, use COCO classes by default.")
                model.dataset_meta = {'classes': get_classes('coco')}
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn('weights is None, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

        # Priority: args.palette -> config -> checkpoint
        if self.palette != 'none':
            model.dataset_meta['palette'] = self.palette
        else:
            test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
            # Lazy init: we only need the metainfo.
            test_dataset_cfg['lazy_init'] = True
            metainfo = DATASETS.build(test_dataset_cfg).metainfo
            cfg_palette = metainfo.get('palette', None)
            if cfg_palette is not None:
                model.dataset_meta['palette'] = cfg_palette
            else:
                if 'palette' not in model.dataset_meta:
                    warnings.warn(
                        'palette does not exist, random is used by default. '
                        'You can also set the palette to customize.')
                    model.dataset_meta['palette'] = 'random'
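
    # For example, a hypothetical ``DetInferencer(model, palette='voc')``
    # call takes precedence over any palette found in the config or
    # checkpoint, while the default ``palette='none'`` falls back to the
    # config's dataset metainfo, then to a palette already stored in the
    # checkpoint's dataset_meta, and finally to 'random'.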

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``img_id`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'img_id')

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1

    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        """Initialize visualizers.

        Args:
            cfg (ConfigType): Config containing the visualizer information.

        Returns:
            Visualizer or None: Visualizer initialized with config.
        """
        visualizer = super()._init_visualizer(cfg)
        visualizer.dataset_meta = self.model.dataset_meta
        return visualizer

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to their type:

        - list or tuple: return inputs
        - str:

            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to a file, a URL, or another type of string
              according to the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so
                # only those backends that implement `isdir` could accept
                # the inputs as a directory.
                filename_list = list_dir_or_file(
                    inputs, list_dir=False, suffix=IMG_EXTENSIONS)
                inputs = [
                    join_path(inputs, filename) for filename in filename_list
                ]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)
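
    # For example (hypothetical paths): ``_inputs_to_list('demo/')`` expands
    # to every image file found under ``demo/``, while
    # ``_inputs_to_list('demo.jpg')`` simply becomes ``['demo.jpg']``.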

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess
        should return an iterable object, of which each item will be used
        as the input of ``model.test_step``.

        ``BaseInferencer.preprocess`` will return an iterable of chunked
        data, which will be used in ``__call__`` like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(inputs, batch_size)
        yield from map(self.collate_fn, chunked_data)

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from inputs.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    chunk_data.append((inputs_, self.pipeline(inputs_)))
                yield chunk_data
            except StopIteration:
                if chunk_data:
                    yield chunk_data
                break
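
    # For instance, five inputs with ``chunk_size=2`` yield chunks of sizes
    # 2, 2 and 1, where each chunk item is an
    # ``(original_input, pipelined_input)`` pair.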

    # TODO: Video and Webcam are currently not supported and
    # may consume too much memory if your input folder has a lot of
    # images. This will be optimized later.
    def __call__(self,
                 inputs: InputsType,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 no_save_vis: bool = False,
                 draw_pred: bool = True,
                 pred_score_thr: float = 0.3,
                 return_datasample: bool = False,
                 print_result: bool = False,
                 no_save_pred: bool = True,
                 out_dir: str = '',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Inference batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (int): The interval of show (s). Defaults to 0.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            return_datasample (bool): Whether to return results as
                :obj:`DetDataSample`. Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to True.
            out_dir (str): Dir to save the inference results or
                visualization. If left as empty, no file will be saved.
                Defaults to ''.
            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)

        results_dict = {'predictions': [], 'visualization': []}
        for ori_inputs, data in track(inputs, description='Inference'):
            preds = self.forward(data, **forward_kwargs)
            visualization = self.visualize(
                ori_inputs,
                preds,
                return_vis=return_vis,
                show=show,
                wait_time=wait_time,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                no_save_vis=no_save_vis,
                img_out_dir=out_dir,
                **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasample=return_datasample,
                print_result=print_result,
                no_save_pred=no_save_pred,
                pred_out_dir=out_dir,
                **postprocess_kwargs)
            results_dict['predictions'].extend(results['predictions'])
            if results['visualization'] is not None:
                results_dict['visualization'].extend(
                    results['visualization'])
        return results_dict

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  no_save_vis: bool = False,
                  img_out_dir: str = '',
                  **kwargs) -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (int): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.
        """
        if no_save_vis:
            img_out_dir = ''

        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []
        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # Convert from BGR (mmcv default) to RGB for visualization.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, 'vis',
                                img_name) if img_out_dir != '' else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results

    def postprocess(
        self,
        preds: PredType,
        visualization: Optional[List[np.ndarray]] = None,
        return_datasample: bool = False,
        print_result: bool = False,
        no_save_pred: bool = False,
        pred_out_dir: str = '',
        **kwargs,
    ) -> Dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            visualization (Optional[List[np.ndarray]]): Visualized
                predictions.
            return_datasample (bool): Whether to use Datasample to store
                inference results. If False, dict will be used.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to False.
            pred_out_dir (str): Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key
            ``predictions`` and ``visualization``.

            - ``visualization`` (Any): Returned by :meth:`visualize`.
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it usually should be a
              json-serializable dict containing only basic data elements
              such as strings and numbers.
        """
        if no_save_pred:
            pred_out_dir = ''

        result_dict = {}
        results = preds
        if not return_datasample:
            results = []
            for pred in preds:
                result = self.pred2dict(pred, pred_out_dir)
                results.append(result)
        elif pred_out_dir != '':
            warnings.warn('Currently does not support saving datasample '
                          'when return_datasample is set to True. '
                          'Prediction results are not saved!')
        # Add img to the results after printing and dumping.
        result_dict['predictions'] = results
        if print_result:
            print(result_dict)
        result_dict['visualization'] = visualization
        return result_dict
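
    # For example, with ``return_datasample=False`` the returned value looks
    # like ``{'predictions': [{'bboxes': [...], 'labels': [...],
    # 'scores': [...]}, ...], 'visualization': [...]}``, one prediction dict
    # per input image.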

    # TODO: The data format and fields saved in json need further discussion.
    # Maybe should include model name, timestamp, filename, image info etc.
    def pred2dict(self,
                  data_sample: DetDataSample,
                  pred_out_dir: str = '') -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        The result should contain only basic data elements such as strings
        and numbers so that it is guaranteed to be json-serializable.

        Args:
            data_sample (:obj:`DetDataSample`): Predictions of the model.
            pred_out_dir (str): Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Prediction results.
        """
        is_save_pred = pred_out_dir != ''

        if is_save_pred and 'img_path' in data_sample:
            img_path = osp.basename(data_sample.img_path)
            img_path = osp.splitext(img_path)[0]
            out_img_path = osp.join(pred_out_dir, 'preds',
                                    img_path + '_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds',
                                     img_path + '.json')
        elif is_save_pred:
            out_img_path = osp.join(
                pred_out_dir, 'preds',
                f'{self.num_predicted_imgs}_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds',
                                     f'{self.num_predicted_imgs}.json')
            self.num_predicted_imgs += 1

        result = {}
        if 'pred_instances' in data_sample:
            masks = data_sample.pred_instances.get('masks')
            pred_instances = data_sample.pred_instances.numpy()
            result = {
                'bboxes': pred_instances.bboxes.tolist(),
                'labels': pred_instances.labels.tolist(),
                'scores': pred_instances.scores.tolist()
            }
            if masks is not None:
                if pred_instances.bboxes.sum() == 0:
                    # Fake bboxes, e.g. those produced by SOLO, are
                    # recovered from the predicted masks.
                    bboxes = mask2bbox(masks.cpu()).numpy().tolist()
                    result['bboxes'] = bboxes
                encode_masks = encode_mask_results(pred_instances.masks)
                for encode_mask in encode_masks:
                    if isinstance(encode_mask['counts'], bytes):
                        encode_mask['counts'] = \
                            encode_mask['counts'].decode()
                result['masks'] = encode_masks

        if 'pred_panoptic_seg' in data_sample:
            if VOID is None:
                raise RuntimeError(
                    'panopticapi is not installed, please install it by: '
                    'pip install git+https://github.com/cocodataset/'
                    'panopticapi.git.')

            pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
            pan[pan % INSTANCE_OFFSET == len(
                self.model.dataset_meta['classes'])] = VOID
            pan = id2rgb(pan).astype(np.uint8)

            if is_save_pred:
                mmcv.imwrite(pan[:, :, ::-1], out_img_path)
                result['panoptic_seg_path'] = out_img_path
            else:
                result['panoptic_seg'] = pan

        if is_save_pred:
            mmengine.dump(result, out_json_path)

        return result