1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Tuple
- from torch import Tensor
- from mmdet.registry import MODELS
- from .standard_roi_head import StandardRoIHead
@MODELS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for `Double Head RCNN <https://arxiv.org/abs/1904.06493>`_.

    The classification and regression branches get separately extracted RoI
    features. The regression features are pooled with an extra
    ``roi_scale_factor`` passed to the RoI extractor.

    Args:
        reg_roi_scale_factor (float): The scale factor to extend the rois
            used to extract the regression features.
    """

    def __init__(self, reg_roi_scale_factor: float, **kwargs):
        super().__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Run the box head on RoI features (shared by train and test).

        Args:
            x (tuple[Tensor]): Multi-level image features. Only the first
                ``self.bbox_roi_extractor.num_inputs`` levels are used.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: A dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): The classification-branch RoI
                  features (after the shared head, when one is configured).
        """
        extractor = self.bbox_roi_extractor
        # Only the leading ``num_inputs`` feature levels feed the extractor.
        level_feats = x[:extractor.num_inputs]

        # Two extractions over the same RoIs: the regression branch passes
        # ``roi_scale_factor`` so its RoIs are extended by that factor.
        cls_feats = extractor(level_feats, rois)
        reg_feats = extractor(
            level_feats, rois, roi_scale_factor=self.reg_roi_scale_factor)

        if self.with_shared_head:
            shared = self.shared_head
            # Tuple RHS is evaluated left to right: cls first, then reg.
            cls_feats, reg_feats = shared(cls_feats), shared(reg_feats)

        cls_score, bbox_pred = self.bbox_head(cls_feats, reg_feats)
        return {
            'cls_score': cls_score,
            'bbox_pred': bbox_pred,
            'bbox_feats': cls_feats,
        }
|