123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import warnings
- from pathlib import Path
- from typing import Optional, Sequence, Union
- import numpy as np
- import torch
- import torch.nn as nn
- from mmcv.ops import RoIPool
- from mmcv.transforms import Compose
- from mmengine.config import Config
- from mmengine.model.utils import revert_sync_batchnorm
- from mmengine.registry import init_default_scope
- from mmengine.runner import load_checkpoint
- from mmdet.registry import DATASETS
- from ..evaluation import get_classes
- from ..registry import MODELS
- from ..structures import DetDataSample, SampleList
- from ..utils import get_test_pipeline_cfg
def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Build a detector from a config and optionally load its weights.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Path of a
            config file, or an already loaded config object. A config object
            is modified in place (option merging / ``init_cfg`` reset).
        checkpoint (str, optional): Checkpoint path. If left as None, no
            weights are loaded and the COCO class names are used.
        palette (str): Color palette used for visualization. Any value other
            than 'none' is stored as is. With 'none', the palette found in
            the metainfo of the test dataset is used first, then a palette
            already present in the checkpoint's dataset meta, and finally
            'random'. Currently, supports 'coco', 'voc', 'citys' and
            'random'. Defaults to 'none'.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options merged into ``config``. The
            backbone ``init_cfg`` is only reset when this is None.

    Returns:
        nn.Module: The constructed detector in eval mode, carrying the
        config as ``model.cfg`` and class/palette info as
        ``model.dataset_meta``.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config``.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    else:
        if not isinstance(config, Config):
            raise TypeError('config must be a filename or Config object, '
                            f'but got {type(config)}')

    if cfg_options is None:
        # Without explicit overrides, drop the backbone init_cfg (if any).
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    else:
        config.merge_from_dict(cfg_options)

    init_default_scope(config.get('default_scope', 'mmdet'))

    model = revert_sync_batchnorm(MODELS.build(config.model))

    # Message of the warning to emit when the class names cannot be taken
    # from a checkpoint; None means the checkpoint provided them.
    coco_fallback_msg = None
    if checkpoint is None:
        coco_fallback_msg = 'checkpoint is None, use COCO classes by default.'
    else:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        meta = loaded.get('meta', {})
        if 'dataset_meta' in meta:
            # mmdet 3.x, all keys should be lowercase
            model.dataset_meta = dict(
                (key.lower(), value)
                for key, value in meta['dataset_meta'].items())
        elif 'CLASSES' in meta:
            # < mmdet 3.x
            model.dataset_meta = {'classes': meta['CLASSES']}
        else:
            coco_fallback_msg = (
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')

    if coco_fallback_msg is not None:
        warnings.simplefilter('once')
        warnings.warn(coco_fallback_msg)
        model.dataset_meta = {'classes': get_classes('coco')}

    # Priority: args.palette -> config -> checkpoint
    if palette == 'none':
        dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # lazy init. We only need the metainfo.
        dataset_cfg['lazy_init'] = True
        cfg_palette = DATASETS.build(dataset_cfg).metainfo.get(
            'palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        elif 'palette' not in model.dataset_meta:
            warnings.warn(
                'palette does not exist, random is used by default. '
                'You can also set the palette to customize.')
            model.dataset_meta['palette'] = 'random'
    else:
        model.dataset_meta['palette'] = palette

    # Keep the config on the model for convenience.
    model.cfg = config
    model.to(device)
    model.eval()
    return model
# Accepted image inputs: a single path or array, or a sequence of either.
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(
    model: nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None
) -> Union[DetDataSample, SampleList]:
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It is expected to expose
            ``cfg``, ``data_preprocessor`` and ``test_step``.
        imgs (str, ndarray, Sequence[str/ndarray]):
            Either image files or loaded images. Only ``list`` and ``tuple``
            are treated as a batch. The image loader is chosen from the
            first element, so a batch should not mix paths and arrays.
        test_pipeline (:obj:`Compose`, optional): Test pipeline. If None, it
            is built from ``model.cfg``.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned (an empty list for an empty batch), otherwise
        return the detection results directly.
    """
    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]
    elif not imgs:
        # An empty batch has nothing to infer. Returning early also avoids
        # the IndexError that peeking at ``imgs[0]`` below would raise.
        return []

    cfg = model.cfg

    if test_pipeline is None:
        cfg = cfg.copy()
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result
            # in module unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'

        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == 'cpu':
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    result_list = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)

        # build the data pipeline
        data_ = test_pipeline(data_)

        # ``test_step`` is fed a batch of one: wrap both fields in a list.
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]

        # forward the model
        with torch.no_grad():
            results = model.test_step(data_)[0]

        result_list.append(results)

    return result_list if is_batch else result_list[0]
# TODO: Awaiting refactoring
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It is expected to expose
            ``cfg`` and an awaitable ``aforward_test``.
        imgs (str | ndarray | list | tuple): Either image files or loaded
            images. A single image is wrapped into a one-element list.

    Returns:
        Awaitable detection results.
    """
    batch = imgs if isinstance(imgs, (list, tuple)) else [imgs]

    cfg = model.cfg
    if isinstance(batch[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        # NOTE(review): unlike ``inference_detector`` this type is not
        # prefixed with ``mmdet.`` - confirm whether that is intended.
        cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'

    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in batch:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            raw = {'img': img}
        else:
            # add information into dict
            raw = {'img_info': {'filename': img}, 'img_prefix': None}
        # build the data pipeline
        data = test_pipeline(raw)
        datas.append(data)

    # NOTE(review): this check is not restricted to CPU models although the
    # message mentions CPU inference - confirm against the sync version.
    assert not any(
        isinstance(m, RoIPool) for m in model.modules()
    ), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    # NOTE(review): only the last processed sample (``data``) is forwarded;
    # ``datas`` is collected but never used. Presumably the whole batch was
    # meant to be passed - to be settled by the pending refactoring.
    return await model.aforward_test(data, rescale=True)
|