double_roi_head.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. from torch import Tensor
  4. from mmdet.registry import MODELS
  5. from .standard_roi_head import StandardRoIHead
  6. @MODELS.register_module()
  7. class DoubleHeadRoIHead(StandardRoIHead):
  8. """RoI head for `Double Head RCNN <https://arxiv.org/abs/1904.06493>`_.
  9. Args:
  10. reg_roi_scale_factor (float): The scale factor to extend the rois
  11. used to extract the regression features.
  12. """
  13. def __init__(self, reg_roi_scale_factor: float, **kwargs):
  14. super().__init__(**kwargs)
  15. self.reg_roi_scale_factor = reg_roi_scale_factor
  16. def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
  17. """Box head forward function used in both training and testing.
  18. Args:
  19. x (tuple[Tensor]): List of multi-level img features.
  20. rois (Tensor): RoIs with the shape (n, 5) where the first
  21. column indicates batch id of each RoI.
  22. Returns:
  23. dict[str, Tensor]: Usually returns a dictionary with keys:
  24. - `cls_score` (Tensor): Classification scores.
  25. - `bbox_pred` (Tensor): Box energies / deltas.
  26. - `bbox_feats` (Tensor): Extract bbox RoI features.
  27. """
  28. bbox_cls_feats = self.bbox_roi_extractor(
  29. x[:self.bbox_roi_extractor.num_inputs], rois)
  30. bbox_reg_feats = self.bbox_roi_extractor(
  31. x[:self.bbox_roi_extractor.num_inputs],
  32. rois,
  33. roi_scale_factor=self.reg_roi_scale_factor)
  34. if self.with_shared_head:
  35. bbox_cls_feats = self.shared_head(bbox_cls_feats)
  36. bbox_reg_feats = self.shared_head(bbox_reg_feats)
  37. cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
  38. bbox_results = dict(
  39. cls_score=cls_score,
  40. bbox_pred=bbox_pred,
  41. bbox_feats=bbox_cls_feats)
  42. return bbox_results