base.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import Dict, List, Tuple, Union
  4. import torch
  5. from mmengine.model import BaseModel
  6. from torch import Tensor
  7. from mmdet.structures import DetDataSample, OptSampleList, SampleList
  8. from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
  9. from ..utils import samplelist_boxtype2tensor
  10. ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
  11. Tuple[torch.Tensor], torch.Tensor]
  12. class BaseDetector(BaseModel, metaclass=ABCMeta):
  13. """Base class for detectors.
  14. Args:
  15. data_preprocessor (dict or ConfigDict, optional): The pre-process
  16. config of :class:`BaseDataPreprocessor`. it usually includes,
  17. ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
  18. init_cfg (dict or ConfigDict, optional): the config to control the
  19. initialization. Defaults to None.
  20. """
  21. def __init__(self,
  22. data_preprocessor: OptConfigType = None,
  23. init_cfg: OptMultiConfig = None):
  24. super().__init__(
  25. data_preprocessor=data_preprocessor, init_cfg=init_cfg)
  26. @property
  27. def with_neck(self) -> bool:
  28. """bool: whether the detector has a neck"""
  29. return hasattr(self, 'neck') and self.neck is not None
  30. # TODO: these properties need to be carefully handled
  31. # for both single stage & two stage detectors
  32. @property
  33. def with_shared_head(self) -> bool:
  34. """bool: whether the detector has a shared head in the RoI Head"""
  35. return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
  36. @property
  37. def with_bbox(self) -> bool:
  38. """bool: whether the detector has a bbox head"""
  39. return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
  40. or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
  41. @property
  42. def with_mask(self) -> bool:
  43. """bool: whether the detector has a mask head"""
  44. return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
  45. or (hasattr(self, 'mask_head') and self.mask_head is not None))
  46. def forward(self,
  47. inputs: torch.Tensor,
  48. data_samples: OptSampleList = None,
  49. mode: str = 'tensor') -> ForwardResults:
  50. """The unified entry for a forward process in both training and test.
  51. The method should accept three modes: "tensor", "predict" and "loss":
  52. - "tensor": Forward the whole network and return tensor or tuple of
  53. tensor without any post-processing, same as a common nn.Module.
  54. - "predict": Forward and return the predictions, which are fully
  55. processed to a list of :obj:`DetDataSample`.
  56. - "loss": Forward and return a dict of losses according to the given
  57. inputs and data samples.
  58. Note that this method doesn't handle either back propagation or
  59. parameter update, which are supposed to be done in :meth:`train_step`.
  60. Args:
  61. inputs (torch.Tensor): The input tensor with shape
  62. (N, C, ...) in general.
  63. data_samples (list[:obj:`DetDataSample`], optional): A batch of
  64. data samples that contain annotations and predictions.
  65. Defaults to None.
  66. mode (str): Return what kind of value. Defaults to 'tensor'.
  67. Returns:
  68. The return type depends on ``mode``.
  69. - If ``mode="tensor"``, return a tensor or a tuple of tensor.
  70. - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
  71. - If ``mode="loss"``, return a dict of tensor.
  72. """
  73. if mode == 'loss':
  74. return self.loss(inputs, data_samples)
  75. elif mode == 'predict':
  76. return self.predict(inputs, data_samples)
  77. elif mode == 'tensor':
  78. return self._forward(inputs, data_samples)
  79. else:
  80. raise RuntimeError(f'Invalid mode "{mode}". '
  81. 'Only supports loss, predict and tensor mode')
  82. @abstractmethod
  83. def loss(self, batch_inputs: Tensor,
  84. batch_data_samples: SampleList) -> Union[dict, tuple]:
  85. """Calculate losses from a batch of inputs and data samples."""
  86. pass
  87. @abstractmethod
  88. def predict(self, batch_inputs: Tensor,
  89. batch_data_samples: SampleList) -> SampleList:
  90. """Predict results from a batch of inputs and data samples with post-
  91. processing."""
  92. pass
  93. @abstractmethod
  94. def _forward(self,
  95. batch_inputs: Tensor,
  96. batch_data_samples: OptSampleList = None):
  97. """Network forward process.
  98. Usually includes backbone, neck and head forward without any post-
  99. processing.
  100. """
  101. pass
  102. @abstractmethod
  103. def extract_feat(self, batch_inputs: Tensor):
  104. """Extract features from images."""
  105. pass
  106. def add_pred_to_datasample(self, data_samples: SampleList,
  107. results_list: InstanceList) -> SampleList:
  108. """Add predictions to `DetDataSample`.
  109. Args:
  110. data_samples (list[:obj:`DetDataSample`], optional): A batch of
  111. data samples that contain annotations and predictions.
  112. results_list (list[:obj:`InstanceData`]): Detection results of
  113. each image.
  114. Returns:
  115. list[:obj:`DetDataSample`]: Detection results of the
  116. input images. Each DetDataSample usually contain
  117. 'pred_instances'. And the ``pred_instances`` usually
  118. contains following keys.
  119. - scores (Tensor): Classification scores, has a shape
  120. (num_instance, )
  121. - labels (Tensor): Labels of bboxes, has a shape
  122. (num_instances, ).
  123. - bboxes (Tensor): Has a shape (num_instances, 4),
  124. the last dimension 4 arrange as (x1, y1, x2, y2).
  125. """
  126. for data_sample, pred_instances in zip(data_samples, results_list):
  127. data_sample.pred_instances = pred_instances
  128. samplelist_boxtype2tensor(data_samples)
  129. return data_samples