# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.roi_heads import CascadeRoIHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.test_time_augs import merge_aug_masks
from mmdet.models.utils.misc import empty_instances
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox2roi, get_box_tensor
from mmdet.utils import ConfigType, InstanceList, MultiConfig


@MODELS.register_module(force=True)  # force=True avoids duplicate registration
class DeticRoIHead(CascadeRoIHead):
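    """RoI head used by `Detic <https://arxiv.org/abs/2201.02605>`_.

    Compared with ``CascadeRoIHead``, the classification scores are
    sigmoid-activated, averaged over all stages and re-weighted by the
    proposal (objectness) scores as in CenterNet2, and a single mask head
    (instead of one per stage) is shared across stages.
    """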

    def init_mask_head(self, mask_roi_extractor: MultiConfig,
                       mask_head: MultiConfig) -> None:
        """Initialize the mask head and mask roi extractor.

        Args:
            mask_roi_extractor (:obj:`ConfigDict`, dict or list):
                Config of mask roi extractor.
            mask_head (:obj:`ConfigDict`, dict or list): Config of mask head.
        """
        self.mask_head = MODELS.build(mask_head)

        if mask_roi_extractor is not None:
            self.share_roi_extractor = False
            self.mask_roi_extractor = MODELS.build(mask_roi_extractor)
        else:
            self.share_roi_extractor = True
            self.mask_roi_extractor = self.bbox_roi_extractor

    def _refine_roi(self, x: Tuple[Tensor], rois: Tensor,
                    batch_img_metas: List[dict],
                    num_proposals_per_img: Sequence[int], **kwargs) -> tuple:
        """Multi-stage refinement of RoIs.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rois (Tensor): Shape (n, 5), [batch_ind, x1, y1, x2, y2].
            batch_img_metas (list[dict]): List of image information.
            num_proposals_per_img (sequence[int]): Number of proposals
                in each image.

        Returns:
            tuple:

                - rois (Tensor): Refined RoIs.
                - cls_scores (list[Tensor]): Predicted cls scores averaged
                  over all stages, one tensor per image.
                - bbox_preds (list[Tensor]): Bbox branch predictions of the
                  last stage, one tensor per image.
        """
- # "ms" in variable names means multi-stage
- ms_scores = []
- for stage in range(self.num_stages):
- bbox_results = self._bbox_forward(
- stage=stage, x=x, rois=rois, **kwargs)
- # split batch bbox prediction back to each image
- cls_scores = bbox_results['cls_score'].sigmoid()
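            # Unlike the softmax scores used in ``CascadeRoIHead``, Detic's
            # box head predicts sigmoid-activated class scores; the last
            # column is excluded from the foreground argmax/max below, and
            # the per-stage scores are averaged after the loop.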
            bbox_preds = bbox_results['bbox_pred']

            rois = rois.split(num_proposals_per_img, 0)
            cls_scores = cls_scores.split(num_proposals_per_img, 0)
            ms_scores.append(cls_scores)
            bbox_preds = bbox_preds.split(num_proposals_per_img, 0)

            if stage < self.num_stages - 1:
                bbox_head = self.bbox_head[stage]
                refine_rois_list = []
                for i in range(len(batch_img_metas)):
                    if rois[i].shape[0] > 0:
                        bbox_label = cls_scores[i][:, :-1].argmax(dim=1)
                        # Refactor `bbox_head.regress_by_class` to only
                        # accept box tensor without img_idx concatenated.
                        refined_bboxes = bbox_head.regress_by_class(
                            rois[i][:, 1:], bbox_label, bbox_preds[i],
                            batch_img_metas[i])
                        refined_bboxes = get_box_tensor(refined_bboxes)
                        refined_rois = torch.cat(
                            [rois[i][:, [0]], refined_bboxes], dim=1)
                        refine_rois_list.append(refined_rois)
                rois = torch.cat(refine_rois_list)

        # ms_scores aligned
        # average scores of each image by stages
        cls_scores = [
            sum([score[i] for score in ms_scores]) / float(len(ms_scores))
            for i in range(len(batch_img_metas))
        ]  # aligned
        return rois, cls_scores, bbox_preds

    def _bbox_forward(self, stage: int, x: Tuple[Tensor],
                      rois: Tensor) -> dict:
        """Box head forward function used in both training and testing.

        Args:
            stage (int): The current stage in Cascade RoI Head.
            x (tuple[Tensor]): List of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): Extracted bbox RoI features.
        """
        bbox_roi_extractor = self.bbox_roi_extractor[stage]
        bbox_head = self.bbox_head[stage]
        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                        rois)
        # do not support caffe_c4 model anymore
        cls_score, bbox_pred = bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def predict_bbox(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     rpn_results_list: InstanceList,
                     rcnn_test_cfg: ConfigType,
                     rescale: bool = False,
                     **kwargs) -> InstanceList:
        """Perform forward propagation of the bbox head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            batch_img_metas (list[dict]): List of image information.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process. Each item usually contains the
            following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        proposals = [res.bboxes for res in rpn_results_list]
        proposal_scores = [res.scores for res in rpn_results_list]
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = bbox2roi(proposals)
        if rois.shape[0] == 0:
            return empty_instances(
                batch_img_metas,
                rois.device,
                task_type='bbox',
                box_type=self.bbox_head[-1].predict_box_type,
                num_classes=self.bbox_head[-1].num_classes,
                score_per_cls=rcnn_test_cfg is None)

        # rois aligned
        rois, cls_scores, bbox_preds = self._refine_roi(
            x=x,
            rois=rois,
            batch_img_metas=batch_img_metas,
            num_proposals_per_img=num_proposals_per_img,
            **kwargs)

        # score reweighting in CenterNet2
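        # The final score is the geometric mean of the stage-averaged
        # classification score and the proposal (objectness) score,
        # sqrt(cls_score * proposal_score); afterwards only the top-scoring
        # foreground class of each box keeps its score, the rest are zeroed.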
        cls_scores = [(s * ps[:, None])**0.5
                      for s, ps in zip(cls_scores, proposal_scores)]
        cls_scores = [
            s * (s == s[:, :-1].max(dim=1)[0][:, None]).float()
            for s in cls_scores
        ]

        # fast_rcnn_inference
        results_list = self.bbox_head[-1].predict_by_feat(
            rois=rois,
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            batch_img_metas=batch_img_metas,
            rescale=rescale,
            rcnn_test_cfg=rcnn_test_cfg)
        return results_list

    def _mask_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Mask head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
        """
        mask_feats = self.mask_roi_extractor(
            x[:self.mask_roi_extractor.num_inputs], rois)
        # do not support caffe_c4 model anymore
        mask_preds = self.mask_head(mask_feats)
        mask_results = dict(mask_preds=mask_preds)
        return mask_results

    def mask_loss(self, x, sampling_results: List[SamplingResult],
                  batch_gt_instances: InstanceList) -> dict:
        """Run forward function and calculate loss for mask head in training.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.

        Returns:
            dict: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
                - `loss_mask` (dict): A dictionary of mask loss components.
        """
        pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
        mask_results = self._mask_forward(x, pos_rois)

        mask_loss_and_target = self.mask_head.loss_and_target(
            mask_preds=mask_results['mask_preds'],
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=self.train_cfg[-1])
        mask_results.update(mask_loss_and_target)
        return mask_results

    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        raise NotImplementedError

    def predict_mask(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     results_list: List[InstanceData],
                     rescale: bool = False) -> List[InstanceData]:
        """Perform forward propagation of the mask head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            batch_img_metas (list[dict]): List of image information.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process. Each item usually contains the
            following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        bboxes = [res.bboxes for res in results_list]
        mask_rois = bbox2roi(bboxes)
        if mask_rois.shape[0] == 0:
            results_list = empty_instances(
                batch_img_metas,
                mask_rois.device,
                task_type='mask',
                instance_results=results_list,
                mask_thr_binary=self.test_cfg.mask_thr_binary)
            return results_list

        num_mask_rois_per_img = [len(res) for res in results_list]
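        # Only a single forward pass is run here; the prediction is wrapped
        # in a one-element list so the ``merge_aug_masks`` test-time
        # augmentation utility below can be reused unchanged.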
        aug_masks = []
        mask_results = self._mask_forward(x, mask_rois)
        mask_preds = mask_results['mask_preds']
        # split batch mask prediction back to each image
        mask_preds = mask_preds.split(num_mask_rois_per_img, 0)
        aug_masks.append([m.sigmoid().detach() for m in mask_preds])

        merged_masks = []
        for i in range(len(batch_img_metas)):
            aug_mask = [mask[i] for mask in aug_masks]
            merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
            merged_masks.append(merged_mask)

        results_list = self.mask_head.predict_by_feat(
            mask_preds=merged_masks,
            results_list=results_list,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=self.test_cfg,
            rescale=rescale,
            activate_map=True)
        return results_list
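

# Minimal usage sketch (illustrative only): like other mmdet RoI heads, this
# class is normally built from a config via the registry rather than
# instantiated directly. The config below is a placeholder, not a complete
# Detic configuration.
#
#   from mmdet.registry import MODELS
#
#   roi_head = MODELS.build(
#       dict(
#           type='DeticRoIHead',
#           num_stages=3,
#           ...))  # plus bbox_roi_extractor / bbox_head / mask_head configs
#   results_list = roi_head.predict_bbox(
#       x, batch_img_metas, rpn_results_list, rcnn_test_cfg=roi_head.test_cfg)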