123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple
- from torch import Tensor
- from mmdet.models.task_modules import SamplingResult
- from mmdet.registry import MODELS
- from mmdet.structures import DetDataSample
- from mmdet.structures.bbox import bbox2roi
- from mmdet.utils import InstanceList
- from ..losses.pisa_loss import carl_loss, isr_p
- from ..utils import unpack_gt_instances
- from .standard_roi_head import StandardRoIHead
@MODELS.register_module()
class PISARoIHead(StandardRoIHead):
    r"""The RoI head for `Prime Sample Attention in Object Detection
    <https://arxiv.org/abs/1904.04821>`_.

    Compared with the parent head, ``loss`` additionally collects the
    per-image negative label weights that some samplers return, and
    ``bbox_loss`` optionally applies ISR-P and the CARL loss when the
    ``isr`` / ``carl`` keys are present in ``self.train_cfg``.
    """

    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: List[DetDataSample]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals. Note that each entry is modified in place: its
                ``bboxes`` field is renamed to ``priors``.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components
        """
        assert len(rpn_results_list) == len(batch_data_samples)
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(batch_data_samples)
        sampling_results = []
        neg_label_weights = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')

            assign_result = self.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])

            # Only some samplers return a
            # ``(sampling_result, neg_label_weight)`` tuple. Reset the weight
            # on every iteration so that a plain sampling result neither
            # raises ``UnboundLocalError`` nor reuses the previous image's
            # weight.
            neg_label_weight = None
            if isinstance(sampling_result, tuple):
                sampling_result, neg_label_weight = sampling_result
            sampling_results.append(sampling_result)
            neg_label_weights.append(neg_label_weight)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self.bbox_loss(
                x, sampling_results, neg_label_weights=neg_label_weights)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        # NOTE: this branch reads ``bbox_results`` produced by the bbox
        # branch above, so it relies on ``self.with_bbox`` being true too.
        if self.with_mask:
            mask_results = self.mask_loss(x, sampling_results,
                                          bbox_results['bbox_feats'],
                                          batch_gt_instances)
            losses.update(mask_results['loss_mask'])

        return losses

    def bbox_loss(self,
                  x: Tuple[Tensor],
                  sampling_results: List[SamplingResult],
                  neg_label_weights: List[Tensor] = None) -> dict:
        """Perform forward propagation and loss calculation of the bbox head on
        the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results.
            neg_label_weights (list[Tensor | None], optional): Image-wise
                label weights for the sampled negatives, one entry per
                sampling result. ``None`` (for the whole argument or for a
                single image) leaves the corresponding label weights
                untouched. Defaults to None.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): Extract bbox RoI features.
                - `loss_bbox` (dict): A dictionary of bbox loss components.
        """
        rois = bbox2roi([res.priors for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)
        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  self.train_cfg)

        # neg_label_weights obtained by sampler is image-wise, mapping back to
        # the corresponding location in label weights.
        # The signature allows ``None``, so guard before iterating instead of
        # indexing ``neg_label_weights[0]`` unconditionally.
        if neg_label_weights is not None:
            # The slice assignment below writes into this tensor in place.
            label_weights = bbox_targets[1]
            cur_num_rois = 0
            for sampling_result, neg_label_weight in zip(
                    sampling_results, neg_label_weights):
                num_pos = sampling_result.pos_inds.size(0)
                num_neg = sampling_result.neg_inds.size(0)
                if neg_label_weight is not None:
                    # NOTE(review): assumes each image's targets are laid out
                    # as positives followed by negatives - confirm against
                    # ``bbox_head.get_targets``.
                    neg_start = cur_num_rois + num_pos
                    label_weights[neg_start:neg_start +
                                  num_neg] = neg_label_weight
                cur_num_rois += num_pos + num_neg

        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']

        # Apply ISR-P
        isr_cfg = self.train_cfg.get('isr', None)
        if isr_cfg is not None:
            bbox_targets = isr_p(
                cls_score,
                bbox_pred,
                bbox_targets,
                rois,
                sampling_results,
                self.bbox_head.loss_cls,
                self.bbox_head.bbox_coder,
                **isr_cfg,
                num_class=self.bbox_head.num_classes)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
                                        *bbox_targets)

        # Add CARL Loss
        carl_cfg = self.train_cfg.get('carl', None)
        if carl_cfg is not None:
            loss_carl = carl_loss(
                cls_score,
                bbox_targets[0],
                bbox_pred,
                bbox_targets[2],
                self.bbox_head.loss_bbox,
                **carl_cfg,
                num_class=self.bbox_head.num_classes)
            loss_bbox.update(loss_carl)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results
|