123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple
- import torch
- from mmengine.config import ConfigDict
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.models.task_modules.samplers import PseudoSampler
- from mmdet.registry import MODELS
- from mmdet.structures import SampleList
- from mmdet.structures.bbox import bbox2roi
- from mmdet.utils import ConfigType, InstanceList, OptConfigType
- from ..utils.misc import empty_instances, unpack_gt_instances
- from .cascade_roi_head import CascadeRoIHead
- @MODELS.register_module()
- class SparseRoIHead(CascadeRoIHead):
- r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
- Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
- and `Instances as Queries <http://arxiv.org/abs/2105.01928>`_
- Args:
- num_stages (int): Number of stage whole iterative process.
- Defaults to 6.
- stage_loss_weights (Tuple[float]): The loss
- weight of each stage. By default all stages have
- the same weight 1.
- bbox_roi_extractor (:obj:`ConfigDict` or dict): Config of box
- roi extractor.
- mask_roi_extractor (:obj:`ConfigDict` or dict): Config of mask
- roi extractor.
- bbox_head (:obj:`ConfigDict` or dict): Config of box head.
- mask_head (:obj:`ConfigDict` or dict): Config of mask head.
- train_cfg (:obj:`ConfigDict` or dict, Optional): Configuration
- information in train stage. Defaults to None.
- test_cfg (:obj:`ConfigDict` or dict, Optional): Configuration
- information in test stage. Defaults to None.
- init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
- dict]): Initialization config dict. Defaults to None.
- """
- def __init__(self,
- num_stages: int = 6,
- stage_loss_weights: Tuple[float] = (1, 1, 1, 1, 1, 1),
- proposal_feature_channel: int = 256,
- bbox_roi_extractor: ConfigType = dict(
- type='SingleRoIExtractor',
- roi_layer=dict(
- type='RoIAlign', output_size=7, sampling_ratio=2),
- out_channels=256,
- featmap_strides=[4, 8, 16, 32]),
- mask_roi_extractor: OptConfigType = None,
- bbox_head: ConfigType = dict(
- type='DIIHead',
- num_classes=80,
- num_fcs=2,
- num_heads=8,
- num_cls_fcs=1,
- num_reg_fcs=3,
- feedforward_channels=2048,
- hidden_channels=256,
- dropout=0.0,
- roi_feat_size=7,
- ffn_act_cfg=dict(type='ReLU', inplace=True)),
- mask_head: OptConfigType = None,
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- init_cfg: OptConfigType = None) -> None:
- assert bbox_roi_extractor is not None
- assert bbox_head is not None
- assert len(stage_loss_weights) == num_stages
- self.num_stages = num_stages
- self.stage_loss_weights = stage_loss_weights
- self.proposal_feature_channel = proposal_feature_channel
- super().__init__(
- num_stages=num_stages,
- stage_loss_weights=stage_loss_weights,
- bbox_roi_extractor=bbox_roi_extractor,
- mask_roi_extractor=mask_roi_extractor,
- bbox_head=bbox_head,
- mask_head=mask_head,
- train_cfg=train_cfg,
- test_cfg=test_cfg,
- init_cfg=init_cfg)
- # train_cfg would be None when run the test.py
- if train_cfg is not None:
- for stage in range(num_stages):
- assert isinstance(self.bbox_sampler[stage], PseudoSampler), \
- 'Sparse R-CNN and QueryInst only support `PseudoSampler`'
- def bbox_loss(self, stage: int, x: Tuple[Tensor],
- results_list: InstanceList, object_feats: Tensor,
- batch_img_metas: List[dict],
- batch_gt_instances: InstanceList) -> dict:
- """Perform forward propagation and loss calculation of the bbox head on
- the features of the upstream network.
- Args:
- stage (int): The current stage in iterative process.
- x (tuple[Tensor]): List of multi-level img features.
- results_list (List[:obj:`InstanceData`]) : List of region
- proposals.
- object_feats (Tensor): The object feature extracted from
- the previous stage.
- batch_img_metas (list[dict]): Meta information of each image.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes``, ``labels``, and
- ``masks`` attributes.
- Returns:
- dict[str, Tensor]: Usually returns a dictionary with keys:
- - `cls_score` (Tensor): Classification scores.
- - `bbox_pred` (Tensor): Box energies / deltas.
- - `bbox_feats` (Tensor): Extract bbox RoI features.
- - `loss_bbox` (dict): A dictionary of bbox loss components.
- """
- proposal_list = [res.bboxes for res in results_list]
- rois = bbox2roi(proposal_list)
- bbox_results = self._bbox_forward(stage, x, rois, object_feats,
- batch_img_metas)
- imgs_whwh = torch.cat(
- [res.imgs_whwh[None, ...] for res in results_list])
- cls_pred_list = bbox_results['detached_cls_scores']
- proposal_list = bbox_results['detached_proposals']
- sampling_results = []
- bbox_head = self.bbox_head[stage]
- for i in range(len(batch_img_metas)):
- pred_instances = InstanceData()
- # TODO: Enhance the logic
- pred_instances.bboxes = proposal_list[i] # for assinger
- pred_instances.scores = cls_pred_list[i]
- pred_instances.priors = proposal_list[i] # for sampler
- assign_result = self.bbox_assigner[stage].assign(
- pred_instances=pred_instances,
- gt_instances=batch_gt_instances[i],
- gt_instances_ignore=None,
- img_meta=batch_img_metas[i])
- sampling_result = self.bbox_sampler[stage].sample(
- assign_result, pred_instances, batch_gt_instances[i])
- sampling_results.append(sampling_result)
- bbox_results.update(sampling_results=sampling_results)
- cls_score = bbox_results['cls_score']
- decoded_bboxes = bbox_results['decoded_bboxes']
- cls_score = cls_score.view(-1, cls_score.size(-1))
- decoded_bboxes = decoded_bboxes.view(-1, 4)
- bbox_loss_and_target = bbox_head.loss_and_target(
- cls_score,
- decoded_bboxes,
- sampling_results,
- self.train_cfg[stage],
- imgs_whwh=imgs_whwh,
- concat=True)
- bbox_results.update(bbox_loss_and_target)
- # propose for the new proposal_list
- proposal_list = []
- for idx in range(len(batch_img_metas)):
- results = InstanceData()
- results.imgs_whwh = results_list[idx].imgs_whwh
- results.bboxes = bbox_results['detached_proposals'][idx]
- proposal_list.append(results)
- bbox_results.update(results_list=proposal_list)
- return bbox_results
- def _bbox_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor,
- object_feats: Tensor,
- batch_img_metas: List[dict]) -> dict:
- """Box head forward function used in both training and testing. Returns
- all regression, classification results and a intermediate feature.
- Args:
- stage (int): The current stage in iterative process.
- x (tuple[Tensor]): List of multi-level img features.
- rois (Tensor): RoIs with the shape (n, 5) where the first
- column indicates batch id of each RoI.
- Each dimension means (img_index, x1, y1, x2, y2).
- object_feats (Tensor): The object feature extracted from
- the previous stage.
- batch_img_metas (list[dict]): Meta information of each image.
- Returns:
- dict[str, Tensor]: a dictionary of bbox head outputs,
- Containing the following results:
- - cls_score (Tensor): The score of each class, has
- shape (batch_size, num_proposals, num_classes)
- when use focal loss or
- (batch_size, num_proposals, num_classes+1)
- otherwise.
- - decoded_bboxes (Tensor): The regression results
- with shape (batch_size, num_proposal, 4).
- The last dimension 4 represents
- [tl_x, tl_y, br_x, br_y].
- - object_feats (Tensor): The object feature extracted
- from current stage
- - detached_cls_scores (list[Tensor]): The detached
- classification results, length is batch_size, and
- each tensor has shape (num_proposal, num_classes).
- - detached_proposals (list[tensor]): The detached
- regression results, length is batch_size, and each
- tensor has shape (num_proposal, 4). The last
- dimension 4 represents [tl_x, tl_y, br_x, br_y].
- """
- num_imgs = len(batch_img_metas)
- bbox_roi_extractor = self.bbox_roi_extractor[stage]
- bbox_head = self.bbox_head[stage]
- bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
- rois)
- cls_score, bbox_pred, object_feats, attn_feats = bbox_head(
- bbox_feats, object_feats)
- fake_bbox_results = dict(
- rois=rois,
- bbox_targets=(rois.new_zeros(len(rois), dtype=torch.long), None),
- bbox_pred=bbox_pred.view(-1, bbox_pred.size(-1)),
- cls_score=cls_score.view(-1, cls_score.size(-1)))
- fake_sampling_results = [
- InstanceData(pos_is_gt=rois.new_zeros(object_feats.size(1)))
- for _ in range(len(batch_img_metas))
- ]
- results_list = bbox_head.refine_bboxes(
- sampling_results=fake_sampling_results,
- bbox_results=fake_bbox_results,
- batch_img_metas=batch_img_metas)
- proposal_list = [res.bboxes for res in results_list]
- bbox_results = dict(
- cls_score=cls_score,
- decoded_bboxes=torch.cat(proposal_list),
- object_feats=object_feats,
- attn_feats=attn_feats,
- # detach then use it in label assign
- detached_cls_scores=[
- cls_score[i].detach() for i in range(num_imgs)
- ],
- detached_proposals=[item.detach() for item in proposal_list])
- return bbox_results
- def _mask_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor,
- attn_feats) -> dict:
- """Mask head forward function used in both training and testing.
- Args:
- stage (int): The current stage in Cascade RoI Head.
- x (tuple[Tensor]): Tuple of multi-level img features.
- rois (Tensor): RoIs with the shape (n, 5) where the first
- column indicates batch id of each RoI.
- attn_feats (Tensot): Intermediate feature get from the last
- diihead, has shape
- (batch_size*num_proposals, feature_dimensions)
- Returns:
- dict: Usually returns a dictionary with keys:
- - `mask_preds` (Tensor): Mask prediction.
- """
- mask_roi_extractor = self.mask_roi_extractor[stage]
- mask_head = self.mask_head[stage]
- mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
- rois)
- # do not support caffe_c4 model anymore
- mask_preds = mask_head(mask_feats, attn_feats)
- mask_results = dict(mask_preds=mask_preds)
- return mask_results
- def mask_loss(self, stage: int, x: Tuple[Tensor], bbox_results: dict,
- batch_gt_instances: InstanceList,
- rcnn_train_cfg: ConfigDict) -> dict:
- """Run forward function and calculate loss for mask head in training.
- Args:
- stage (int): The current stage in Cascade RoI Head.
- x (tuple[Tensor]): Tuple of multi-level img features.
- bbox_results (dict): Results obtained from `bbox_loss`.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes``, ``labels``, and
- ``masks`` attributes.
- rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
- Returns:
- dict: Usually returns a dictionary with keys:
- - `mask_preds` (Tensor): Mask prediction.
- - `loss_mask` (dict): A dictionary of mask loss components.
- """
- attn_feats = bbox_results['attn_feats']
- sampling_results = bbox_results['sampling_results']
- pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
- attn_feats = torch.cat([
- feats[res.pos_inds]
- for (feats, res) in zip(attn_feats, sampling_results)
- ])
- mask_results = self._mask_forward(stage, x, pos_rois, attn_feats)
- mask_loss_and_target = self.mask_head[stage].loss_and_target(
- mask_preds=mask_results['mask_preds'],
- sampling_results=sampling_results,
- batch_gt_instances=batch_gt_instances,
- rcnn_train_cfg=rcnn_train_cfg)
- mask_results.update(mask_loss_and_target)
- return mask_results
- def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
- batch_data_samples: SampleList) -> dict:
- """Perform forward propagation and loss calculation of the detection
- roi on the features of the upstream network.
- Args:
- x (tuple[Tensor]): List of multi-level img features.
- rpn_results_list (List[:obj:`InstanceData`]): List of region
- proposals.
- batch_data_samples (list[:obj:`DetDataSample`]): The batch
- data samples. It usually includes information such
- as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
- Returns:
- dict: a dictionary of loss components of all stage.
- """
- outputs = unpack_gt_instances(batch_data_samples)
- batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
- = outputs
- object_feats = torch.cat(
- [res.pop('features')[None, ...] for res in rpn_results_list])
- results_list = rpn_results_list
- losses = {}
- for stage in range(self.num_stages):
- stage_loss_weight = self.stage_loss_weights[stage]
- # bbox head forward and loss
- bbox_results = self.bbox_loss(
- stage=stage,
- x=x,
- object_feats=object_feats,
- results_list=results_list,
- batch_img_metas=batch_img_metas,
- batch_gt_instances=batch_gt_instances)
- for name, value in bbox_results['loss_bbox'].items():
- losses[f's{stage}.{name}'] = (
- value * stage_loss_weight if 'loss' in name else value)
- if self.with_mask:
- mask_results = self.mask_loss(
- stage=stage,
- x=x,
- bbox_results=bbox_results,
- batch_gt_instances=batch_gt_instances,
- rcnn_train_cfg=self.train_cfg[stage])
- for name, value in mask_results['loss_mask'].items():
- losses[f's{stage}.{name}'] = (
- value * stage_loss_weight if 'loss' in name else value)
- object_feats = bbox_results['object_feats']
- results_list = bbox_results['results_list']
- return losses
- def predict_bbox(self,
- x: Tuple[Tensor],
- batch_img_metas: List[dict],
- rpn_results_list: InstanceList,
- rcnn_test_cfg: ConfigType,
- rescale: bool = False) -> InstanceList:
- """Perform forward propagation of the bbox head and predict detection
- results on the features of the upstream network.
- Args:
- x(tuple[Tensor]): Feature maps of all scale level.
- batch_img_metas (list[dict]): List of image information.
- rpn_results_list (list[:obj:`InstanceData`]): List of region
- proposals.
- rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- Returns:
- list[:obj:`InstanceData`]: Detection results of each image
- after the post process.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- """
- proposal_list = [res.bboxes for res in rpn_results_list]
- object_feats = torch.cat(
- [res.pop('features')[None, ...] for res in rpn_results_list])
- if all([proposal.shape[0] == 0 for proposal in proposal_list]):
- # There is no proposal in the whole batch
- return empty_instances(
- batch_img_metas, x[0].device, task_type='bbox')
- for stage in range(self.num_stages):
- rois = bbox2roi(proposal_list)
- bbox_results = self._bbox_forward(stage, x, rois, object_feats,
- batch_img_metas)
- object_feats = bbox_results['object_feats']
- cls_score = bbox_results['cls_score']
- proposal_list = bbox_results['detached_proposals']
- num_classes = self.bbox_head[-1].num_classes
- if self.bbox_head[-1].loss_cls.use_sigmoid:
- cls_score = cls_score.sigmoid()
- else:
- cls_score = cls_score.softmax(-1)[..., :-1]
- topk_inds_list = []
- results_list = []
- for img_id in range(len(batch_img_metas)):
- cls_score_per_img = cls_score[img_id]
- scores_per_img, topk_inds = cls_score_per_img.flatten(0, 1).topk(
- self.test_cfg.max_per_img, sorted=False)
- labels_per_img = topk_inds % num_classes
- bboxes_per_img = proposal_list[img_id][topk_inds // num_classes]
- topk_inds_list.append(topk_inds)
- if rescale and bboxes_per_img.size(0) > 0:
- assert batch_img_metas[img_id].get('scale_factor') is not None
- scale_factor = bboxes_per_img.new_tensor(
- batch_img_metas[img_id]['scale_factor']).repeat((1, 2))
- bboxes_per_img = (
- bboxes_per_img.view(bboxes_per_img.size(0), -1, 4) /
- scale_factor).view(bboxes_per_img.size()[0], -1)
- results = InstanceData()
- results.bboxes = bboxes_per_img
- results.scores = scores_per_img
- results.labels = labels_per_img
- results_list.append(results)
- if self.with_mask:
- for img_id in range(len(batch_img_metas)):
- # add positive information in InstanceData to predict
- # mask results in `mask_head`.
- proposals = bbox_results['detached_proposals'][img_id]
- topk_inds = topk_inds_list[img_id]
- attn_feats = bbox_results['attn_feats'][img_id]
- results_list[img_id].proposals = proposals
- results_list[img_id].topk_inds = topk_inds
- results_list[img_id].attn_feats = attn_feats
- return results_list
- def predict_mask(self,
- x: Tuple[Tensor],
- batch_img_metas: List[dict],
- results_list: InstanceList,
- rescale: bool = False) -> InstanceList:
- """Perform forward propagation of the mask head and predict detection
- results on the features of the upstream network.
- Args:
- x (tuple[Tensor]): Feature maps of all scale level.
- batch_img_metas (list[dict]): List of image information.
- results_list (list[:obj:`InstanceData`]): Detection results of
- each image. Each item usually contains following keys:
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- - proposal (Tensor): Bboxes predicted from bbox_head,
- has a shape (num_instances, 4).
- - topk_inds (Tensor): Topk indices of each image, has
- shape (num_instances, )
- - attn_feats (Tensor): Intermediate feature get from the last
- diihead, has shape (num_instances, feature_dimensions)
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- Returns:
- list[:obj:`InstanceData`]: Detection results of each image
- after the post process.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- - masks (Tensor): Has a shape (num_instances, H, W).
- """
- proposal_list = [res.pop('proposals') for res in results_list]
- topk_inds_list = [res.pop('topk_inds') for res in results_list]
- attn_feats = torch.cat(
- [res.pop('attn_feats')[None, ...] for res in results_list])
- rois = bbox2roi(proposal_list)
- if rois.shape[0] == 0:
- results_list = empty_instances(
- batch_img_metas,
- rois.device,
- task_type='mask',
- instance_results=results_list,
- mask_thr_binary=self.test_cfg.mask_thr_binary)
- return results_list
- last_stage = self.num_stages - 1
- mask_results = self._mask_forward(last_stage, x, rois, attn_feats)
- num_imgs = len(batch_img_metas)
- mask_results['mask_preds'] = mask_results['mask_preds'].reshape(
- num_imgs, -1, *mask_results['mask_preds'].size()[1:])
- num_classes = self.bbox_head[-1].num_classes
- mask_preds = []
- for img_id in range(num_imgs):
- topk_inds = topk_inds_list[img_id]
- masks_per_img = mask_results['mask_preds'][img_id].flatten(
- 0, 1)[topk_inds]
- masks_per_img = masks_per_img[:, None,
- ...].repeat(1, num_classes, 1, 1)
- mask_preds.append(masks_per_img)
- results_list = self.mask_head[-1].predict_by_feat(
- mask_preds,
- results_list,
- batch_img_metas,
- rcnn_test_cfg=self.test_cfg,
- rescale=rescale)
- return results_list
- # TODO: Need to refactor later
- def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
- batch_data_samples: SampleList) -> tuple:
- """Network forward process. Usually includes backbone, neck and head
- forward without any post-processing.
- Args:
- x (List[Tensor]): Multi-level features that may have different
- resolutions.
- rpn_results_list (List[:obj:`InstanceData`]): List of region
- proposals.
- batch_data_samples (list[:obj:`DetDataSample`]): The batch
- data samples. It usually includes information such
- as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
- Returns
- tuple: A tuple of features from ``bbox_head`` and ``mask_head``
- forward.
- """
- outputs = unpack_gt_instances(batch_data_samples)
- (batch_gt_instances, batch_gt_instances_ignore,
- batch_img_metas) = outputs
- all_stage_bbox_results = []
- object_feats = torch.cat(
- [res.pop('features')[None, ...] for res in rpn_results_list])
- results_list = rpn_results_list
- if self.with_bbox:
- for stage in range(self.num_stages):
- bbox_results = self.bbox_loss(
- stage=stage,
- x=x,
- results_list=results_list,
- object_feats=object_feats,
- batch_img_metas=batch_img_metas,
- batch_gt_instances=batch_gt_instances)
- bbox_results.pop('loss_bbox')
- # torch.jit does not support obj:SamplingResult
- bbox_results.pop('results_list')
- bbox_res = bbox_results.copy()
- bbox_res.pop('sampling_results')
- all_stage_bbox_results.append((bbox_res, ))
- if self.with_mask:
- attn_feats = bbox_results['attn_feats']
- sampling_results = bbox_results['sampling_results']
- pos_rois = bbox2roi(
- [res.pos_priors for res in sampling_results])
- attn_feats = torch.cat([
- feats[res.pos_inds]
- for (feats, res) in zip(attn_feats, sampling_results)
- ])
- mask_results = self._mask_forward(stage, x, pos_rois,
- attn_feats)
- all_stage_bbox_results[-1] += (mask_results, )
- return tuple(all_stage_bbox_results)
|