# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Tuple

from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig


class BaseRoIHead(BaseModule, metaclass=ABCMeta):
    """Base class for RoIHeads.

    Subclasses must implement :meth:`init_bbox_head`,
    :meth:`init_mask_head`, :meth:`init_assigner_sampler` and :meth:`loss`.
    :meth:`predict` additionally calls ``self.predict_bbox`` and
    ``self.predict_mask``, which are not defined in this class and are
    therefore expected to be supplied by the subclass.

    Args:
        bbox_roi_extractor (:obj:`ConfigDict` or dict or list, optional):
            Forwarded as the first argument of :meth:`init_bbox_head`.
            Only used when ``bbox_head`` is not None. Defaults to None.
        bbox_head (:obj:`ConfigDict` or dict or list, optional): Forwarded
            as the second argument of :meth:`init_bbox_head`. When None,
            :meth:`init_bbox_head` is not called. Defaults to None.
        mask_roi_extractor (:obj:`ConfigDict` or dict or list, optional):
            Forwarded as the first argument of :meth:`init_mask_head`.
            Only used when ``mask_head`` is not None. Defaults to None.
        mask_head (:obj:`ConfigDict` or dict or list, optional): Forwarded
            as the second argument of :meth:`init_mask_head`. When None,
            :meth:`init_mask_head` is not called. Defaults to None.
        shared_head (:obj:`ConfigDict` or dict, optional): When not None,
            built through ``MODELS.build`` and stored as
            ``self.shared_head``. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Stored as
            ``self.train_cfg``. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Stored as
            ``self.test_cfg``; passed to ``predict_bbox`` as
            ``rcnn_test_cfg`` in :meth:`predict`. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list, optional): Passed to
            the parent ``BaseModule`` constructor. Defaults to None.
    """

    def __init__(self,
                 bbox_roi_extractor: OptMultiConfig = None,
                 bbox_head: OptMultiConfig = None,
                 mask_roi_extractor: OptMultiConfig = None,
                 mask_head: OptMultiConfig = None,
                 shared_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Optional components are only created when a config is given, so
        # the corresponding attribute may not exist at all afterwards; the
        # ``with_*`` properties below account for that.
        if shared_head is not None:
            self.shared_head = MODELS.build(shared_head)

        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        if mask_head is not None:
            self.init_mask_head(mask_roi_extractor, mask_head)

        # Always invoked, regardless of which heads were configured.
        self.init_assigner_sampler()

    @property
    def with_bbox(self) -> bool:
        """bool: whether the RoI head contains a `bbox_head`"""
        return getattr(self, 'bbox_head', None) is not None

    @property
    def with_mask(self) -> bool:
        """bool: whether the RoI head contains a `mask_head`"""
        return getattr(self, 'mask_head', None) is not None

    @property
    def with_shared_head(self) -> bool:
        """bool: whether the RoI head contains a `shared_head`"""
        return getattr(self, 'shared_head', None) is not None

    @abstractmethod
    def init_bbox_head(self, *args, **kwargs):
        """Initialize ``bbox_head``"""

    @abstractmethod
    def init_mask_head(self, *args, **kwargs):
        """Initialize ``mask_head``"""

    @abstractmethod
    def init_assigner_sampler(self, *args, **kwargs):
        """Initialize assigner and sampler."""

    @abstractmethod
    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: SampleList):
        """Perform forward propagation and loss calculation of the roi head on
        the features of the upstream network."""

    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): list of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results to
                the original image. Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).

        Raises:
            AssertionError: If the head has no ``bbox_head``.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # TODO: nms_op in mmcv need be enhanced, the bbox result may get
        #  difference when not rescale in bbox_head

        # If it has the mask branch, the bbox branch does not need
        # to be scaled to the original image scale, because the mask
        # branch will scale both bbox and mask at the same time.
        bbox_rescale = rescale if not self.with_mask else False
        results_list = self.predict_bbox(
            x,
            batch_img_metas,
            rpn_results_list,
            rcnn_test_cfg=self.test_cfg,
            rescale=bbox_rescale)

        if self.with_mask:
            results_list = self.predict_mask(
                x, batch_img_metas, results_list, rescale=rescale)

        return results_list