# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

from torch import Tensor

from mmdet.registry import MODELS
from .standard_roi_head import StandardRoIHead


@MODELS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN.

    The box head receives two separately extracted RoI feature tensors:
    one for classification, pooled from the RoIs as given, and one for
    regression, pooled with ``roi_scale_factor=reg_roi_scale_factor``.

    Args:
        reg_roi_scale_factor (float): The scale factor to extend the rois
            used to extract the regression features.
        **kwargs: Passed through unchanged to the parent class
            constructor.
    """

    def __init__(self, reg_roi_scale_factor: float, **kwargs):
        super().__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Run the box head; shared by the training and testing paths.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: A dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): The RoI features extracted for
                  the classification branch.
        """
        extractor = self.bbox_roi_extractor
        # Only the first ``num_inputs`` feature levels feed the extractor.
        level_feats = x[:extractor.num_inputs]

        # Same RoIs and feature levels for both branches; only the
        # regression branch asks the extractor to rescale the RoIs.
        cls_feats = extractor(level_feats, rois)
        reg_feats = extractor(
            level_feats, rois, roi_scale_factor=self.reg_roi_scale_factor)

        if self.with_shared_head:
            shared_head = self.shared_head
            cls_feats = shared_head(cls_feats)
            reg_feats = shared_head(reg_feats)

        cls_score, bbox_pred = self.bbox_head(cls_feats, reg_feats)

        # ``bbox_feats`` deliberately carries the classification features,
        # not the regression ones.
        return {
            'cls_score': cls_score,
            'bbox_pred': bbox_pred,
            'bbox_feats': cls_feats,
        }