123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from abc import ABCMeta, abstractmethod
- from typing import Dict, List, Tuple, Union
- import torch
- from mmengine.model import BaseModel
- from torch import Tensor
- from mmdet.structures import DetDataSample, OptSampleList, SampleList
- from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
- from ..utils import samplelist_boxtype2tensor
# Return type of ``BaseDetector.forward``; which member is produced depends on
# ``mode``: 'loss' -> dict of tensors, 'predict' -> list of DetDataSample,
# 'tensor' -> a single tensor or a tuple of tensors of any length.
ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
                       Tuple[torch.Tensor, ...], torch.Tensor]
class BaseDetector(BaseModel, metaclass=ABCMeta):
    """Base class for detectors.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: whether the detector has a neck (present and not None)."""
        return getattr(self, 'neck', None) is not None

    # TODO: these properties need to be carefully handled
    # for both single stage & two stage detectors
    @property
    def with_shared_head(self) -> bool:
        """bool: whether the detector has a shared head in the RoI Head.

        False when ``roi_head`` is missing or set to None; otherwise the
        value of ``roi_head.with_shared_head``.
        """
        # ``getattr`` with a None default covers both a missing attribute
        # and an explicit ``self.roi_head = None``; a bare ``hasattr``
        # check would let the latter fall through to an AttributeError.
        roi_head = getattr(self, 'roi_head', None)
        return roi_head is not None and roi_head.with_shared_head

    @property
    def with_bbox(self) -> bool:
        """bool: whether the detector has a bbox head.

        True if a non-None ``roi_head`` reports ``with_bbox``, or if a
        non-None ``bbox_head`` attribute exists.
        """
        roi_head = getattr(self, 'roi_head', None)
        return ((roi_head is not None and roi_head.with_bbox)
                or getattr(self, 'bbox_head', None) is not None)

    @property
    def with_mask(self) -> bool:
        """bool: whether the detector has a mask head.

        True if a non-None ``roi_head`` reports ``with_mask``, or if a
        non-None ``mask_head`` attribute exists.
        """
        roi_head = getattr(self, 'roi_head', None)
        return ((roi_head is not None and roi_head.with_mask)
                or getattr(self, 'mask_head', None) is not None)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three values above.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples.

        Subclasses must implement this; it backs ``forward(mode='loss')``.
        """
        pass

    @abstractmethod
    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Subclasses must implement this; it backs ``forward(mode='predict')``.
        """
        pass

    @abstractmethod
    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None):
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing. Subclasses must implement this; it backs
        ``forward(mode='tensor')``.
        """
        pass

    @abstractmethod
    def extract_feat(self, batch_inputs: Tensor):
        """Extract features from images. Must be implemented by subclasses."""
        pass

    def add_pred_to_datasample(self, data_samples: SampleList,
                               results_list: InstanceList) -> SampleList:
        """Add predictions to `DetDataSample`.

        Each element of ``data_samples`` is modified in place: its
        ``pred_instances`` field is set to the matching entry of
        ``results_list``. The same list object is then returned.

        Args:
            data_samples (list[:obj:`DetDataSample`]): A batch of
                data samples that contain annotations and predictions.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            'pred_instances'. And the ``pred_instances`` usually
            contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instances, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        # NOTE(review): ``zip`` stops at the shorter sequence, so the two
        # lists are expected to have one entry per image each.
        for data_sample, pred_instances in zip(data_samples, results_list):
            data_sample.pred_instances = pred_instances
        # Presumably converts box-type objects inside the samples back to
        # plain tensors (helper lives in ..utils) — confirm there.
        samplelist_boxtype2tensor(data_samples)
        return data_samples
|