123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from abc import ABCMeta, abstractmethod
- from typing import Tuple
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.structures import SampleList
- from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
class BaseRoIHead(BaseModule, metaclass=ABCMeta):
    """Base class for RoIHeads.

    Subclasses must implement :meth:`init_bbox_head`,
    :meth:`init_mask_head`, :meth:`init_assigner_sampler` and :meth:`loss`.
    :meth:`predict` additionally calls ``self.predict_bbox`` and, when a
    mask head is configured, ``self.predict_mask``; neither is declared in
    this base class, so subclasses are expected to provide them.

    Args:
        bbox_roi_extractor (OptMultiConfig): Config of the RoI extractor
            for the bbox branch. Forwarded to :meth:`init_bbox_head`
            together with ``bbox_head``. Defaults to None.
        bbox_head (OptMultiConfig): Config of the bbox head. The bbox
            branch is only initialized when this is not None.
            Defaults to None.
        mask_roi_extractor (OptMultiConfig): Config of the RoI extractor
            for the mask branch. Forwarded to :meth:`init_mask_head`
            together with ``mask_head``. Defaults to None.
        mask_head (OptMultiConfig): Config of the mask head. The mask
            branch is only initialized when this is not None.
            Defaults to None.
        shared_head (OptConfigType): Config of a shared head, built with
            ``MODELS.build`` when not None. Defaults to None.
        train_cfg (OptConfigType): Training config, stored as
            ``self.train_cfg``. Defaults to None.
        test_cfg (OptConfigType): Testing config, stored as
            ``self.test_cfg`` and passed to ``predict_bbox`` as
            ``rcnn_test_cfg`` during :meth:`predict`. Defaults to None.
        init_cfg (OptMultiConfig): Initialization config, forwarded to
            :class:`BaseModule`. Defaults to None.
    """

    def __init__(self,
                 bbox_roi_extractor: OptMultiConfig = None,
                 bbox_head: OptMultiConfig = None,
                 mask_roi_extractor: OptMultiConfig = None,
                 mask_head: OptMultiConfig = None,
                 shared_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if shared_head is not None:
            self.shared_head = MODELS.build(shared_head)

        # Each branch is optional; its sub-modules are only created when the
        # corresponding head config is supplied.
        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        if mask_head is not None:
            self.init_mask_head(mask_roi_extractor, mask_head)

        # Called unconditionally, after the heads are set up, so subclasses
        # can configure the assigner/sampler (e.g. from ``self.train_cfg``).
        self.init_assigner_sampler()

    @property
    def with_bbox(self) -> bool:
        """bool: whether the RoI head contains a `bbox_head`"""
        return hasattr(self, 'bbox_head') and self.bbox_head is not None

    @property
    def with_mask(self) -> bool:
        """bool: whether the RoI head contains a `mask_head`"""
        return hasattr(self, 'mask_head') and self.mask_head is not None

    @property
    def with_shared_head(self) -> bool:
        """bool: whether the RoI head contains a `shared_head`"""
        return hasattr(self, 'shared_head') and self.shared_head is not None

    @abstractmethod
    def init_bbox_head(self, *args, **kwargs):
        """Initialize ``bbox_head``.

        Called from ``__init__`` as
        ``init_bbox_head(bbox_roi_extractor, bbox_head)`` only when a
        ``bbox_head`` config is given.
        """
        pass

    @abstractmethod
    def init_mask_head(self, *args, **kwargs):
        """Initialize ``mask_head``.

        Called from ``__init__`` as
        ``init_mask_head(mask_roi_extractor, mask_head)`` only when a
        ``mask_head`` config is given.
        """
        pass

    @abstractmethod
    def init_assigner_sampler(self, *args, **kwargs):
        """Initialize assigner and sampler.

        Called from ``__init__`` with no arguments, after the optional
        heads have been initialized.
        """
        pass

    @abstractmethod
    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: SampleList):
        """Perform forward propagation and loss calculation of the roi head on
        the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples.

        Returns:
            The loss outputs; the exact structure is defined by the
            subclass (presumably a dict of loss components — confirm
            against concrete implementations).
        """

    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results to
                the original image. Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image.
            Each item usually contains the following keys:

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
                  Expected only when the mask branch is enabled.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # TODO: nms_op in mmcv need be enhanced, the bbox result may get
        #  difference when not rescale in bbox_head

        # If it has the mask branch, the bbox branch does not need
        # to be scaled to the original image scale, because the mask
        # branch will scale both bbox and mask at the same time.
        bbox_rescale = rescale if not self.with_mask else False
        results_list = self.predict_bbox(
            x,
            batch_img_metas,
            rpn_results_list,
            rcnn_test_cfg=self.test_cfg,
            rescale=bbox_rescale)

        if self.with_mask:
            results_list = self.predict_mask(
                x, batch_img_metas, results_list, rescale=rescale)

        return results_list
|