123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Tuple
- import torch
- import torch.nn as nn
- from torch import Tensor
- from mmdet.structures.bbox import bbox_overlaps
- from ..task_modules.coders import BaseBBoxCoder
- from ..task_modules.samplers import SamplingResult
def _merge_pos_assigned_gt_inds(
        sampling_results: List[SamplingResult]) -> Tensor:
    """Concatenate per-image ``pos_assigned_gt_inds`` into a single tensor.

    Each image's GT indices are shifted by a running offset so that GT ids
    coming from different images never collide after concatenation.

    Args:
        sampling_results (list[:obj:`SamplingResult`]): Per-image sampling
            results.

    Returns:
        Tensor: A batch-unique GT index for every positive sample, in the
        same order as the per-image results.
    """
    gts = []
    offset = 0
    for res in sampling_results:
        gt_i = res.pos_assigned_gt_inds
        gts.append(gt_i + offset)
        if gt_i.numel() != 0:
            # The offset must accumulate across images. Resetting it to this
            # image's ``max + 1`` lets ids of a later image overlap those of
            # an earlier one once a batch holds three or more images.
            offset = offset + gt_i.max() + 1
    return torch.cat(gts)


def isr_p(cls_score: Tensor,
          bbox_pred: Tensor,
          bbox_targets: Tuple[Tensor],
          rois: Tensor,
          sampling_results: List[SamplingResult],
          loss_cls: nn.Module,
          bbox_coder: BaseBBoxCoder,
          k: float = 2,
          bias: float = 0,
          num_class: int = 80) -> tuple:
    """Importance-based Sample Reweighting (ISR_P), positive part.

    Positive samples are ranked by IoU-HLR. The IoU between each decoded
    prediction and its decoded target is first ranked among samples that
    share the same class and GT. It is then ranked again among all samples
    of that class. The rank is mapped linearly to a weight, passed through
    ``(bias + w * (1 - bias)) ** k`` and rescaled so the weighted positive
    classification loss matches the original one.

    Args:
        cls_score (Tensor): Predicted classification scores.
        bbox_pred (Tensor): Predicted bbox deltas, either class-agnostic
            ``(n, 4)`` or class-specific ``(n, 4 * C)``.
        bbox_targets (tuple[Tensor]): A tuple of bbox targets. They are
            labels, label_weights, bbox_targets, bbox_weights, respectively.
        rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
            (two_stage) in shape (n, 5).
        sampling_results (list[:obj:`SamplingResult`]): Per-image sampling
            results.
        loss_cls (:obj:`nn.Module`): Classification loss func of the head.
            It is called with ``reduction_override='none'``.
        bbox_coder (:obj:`BaseBBoxCoder`): BBox coder of the head.
        k (float): Power of the non-linear mapping. Defaults to 2.
        bias (float): Shift of the non-linear mapping. Defaults to 0.
        num_class (int): Number of classes, defaults to 80.

    Returns:
        tuple[Tensor]: labels, imp_based_label_weights, bbox_targets,
        bbox_target_weights. Note that the input ``label_weights`` tensor is
        updated in place at the positive indices.
    """
    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    pos_labels = labels[pos_label_inds]

    # if no positive samples, return the original targets
    num_pos = float(pos_label_inds.size(0))
    if num_pos == 0:
        return labels, label_weights, bbox_targets, bbox_weights

    # merge pos_assigned_gt_inds of per image to a single tensor
    gts = _merge_pos_assigned_gt_inds(sampling_results)
    assert len(gts) == num_pos

    cls_score = cls_score.detach()
    bbox_pred = bbox_pred.detach()

    # For single stage detectors, rois here indicate anchors, in shape (N, 4)
    # For two stage detectors, rois are in shape (N, 5)
    if rois.size(-1) == 5:
        pos_rois = rois[pos_label_inds][:, 1:]
    else:
        pos_rois = rois[pos_label_inds]

    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
    else:
        pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)

    # compute iou of the predicted bbox and the corresponding GT
    pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
    pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
    target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
    ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)

    # Advanced indexing returns a copy, so label_weights is untouched until
    # the final assignment below.
    pos_imp_weights = label_weights[pos_label_inds]
    # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally,
    # then sorted again within the same-rank group
    max_l_num = pos_labels.bincount().max()
    for label in pos_labels.unique():
        l_inds = (pos_labels == label).nonzero().view(-1)
        l_gts = gts[l_inds]
        for t in l_gts.unique():
            t_inds = l_inds[l_gts == t]
            t_ious = ious[t_inds]
            _, t_iou_rank_idx = t_ious.sort(descending=True)
            _, t_iou_rank = t_iou_rank_idx.sort()
            # Add an integer offset (max_l_num - local rank) to each IoU.
            # The per-class sort below then groups samples by local rank
            # first and uses the IoU to order samples within a rank.
            ious[t_inds] += max_l_num - t_iou_rank.float()
        l_ious = ious[l_inds]
        _, l_iou_rank_idx = l_ious.sort(descending=True)
        _, l_iou_rank = l_iou_rank_idx.sort()  # IoU-HLR
        # linearly map HLR to label weights
        pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num

    pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)

    # normalize to make the new weighted loss value equal to the original loss
    pos_loss_cls = loss_cls(
        cls_score[pos_label_inds], pos_labels, reduction_override='none')
    if pos_loss_cls.dim() > 1:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
                                                                        None]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
    else:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights
    new_loss_sum = new_pos_loss_cls.sum()
    # Skip the rescale on 0 / 0 (e.g. a zero positive loss). Otherwise the
    # ratio is NaN and would be written into label_weights, poisoning the
    # whole classification loss.
    if new_loss_sum != 0:
        pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_loss_sum
        pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
    label_weights[pos_label_inds] = pos_imp_weights

    bbox_targets = labels, label_weights, bbox_targets, bbox_weights
    return bbox_targets
def carl_loss(cls_score: Tensor,
              labels: Tensor,
              bbox_pred: Tensor,
              bbox_targets: Tensor,
              loss_bbox: nn.Module,
              k: float = 1,
              bias: float = 0.2,
              avg_factor: Optional[int] = None,
              sigmoid: bool = False,
              num_class: int = 80) -> dict:
    """Classification-Aware Regression Loss (CARL).

    Each positive sample's regression loss is scaled by
    ``(bias + (1 - bias) * score) ** k``. Here ``score`` is the predicted
    probability of the sample's target class. The weights are normalized to
    sum to the number of positives and are not detached, so the regression
    loss also back-propagates into ``cls_score``.

    Args:
        cls_score (Tensor): Predicted classification scores.
        labels (Tensor): Targets of classification.
        bbox_pred (Tensor): Predicted bbox deltas, either class-agnostic
            ``(N, 4)`` or class-specific ``(N, 4 * C)``.
        bbox_targets (Tensor): Target of bbox regression.
        loss_bbox (func): Regression loss func of the head. It is called
            with ``reduction_override='none'``.
        k (float): Power of the non-linear mapping. Defaults to 1.
        bias (float): Shift of the non-linear mapping. Defaults to 0.2.
        avg_factor (int, optional): Average factor used in regression loss.
            Defaults to ``bbox_targets.size(0)`` when None.
        sigmoid (bool): If True, apply sigmoid to the classification score.
            Otherwise apply softmax over the last dim. Defaults to False.
        num_class (int): Number of classes, defaults to 80.

    Returns:
        dict: CARL loss dict with key ``loss_carl`` holding a 1-element
        tensor.
    """
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    if pos_label_inds.numel() == 0:
        return dict(loss_carl=cls_score.sum()[None] * 0.)
    pos_labels = labels[pos_label_inds]

    # multiply pos_cls_score with the corresponding bbox weight
    # and remain gradient
    if sigmoid:
        pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
    else:
        pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
    carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)

    # normalize carl_loss_weight to make its sum equal to num positive
    num_pos = float(pos_cls_score.size(0))
    weight_ratio = num_pos / carl_loss_weights.sum()
    carl_loss_weights *= weight_ratio

    if avg_factor is None:
        avg_factor = bbox_targets.size(0)
    # if is class agnostic, bbox pred is in shape (N, 4)
    # otherwise, bbox pred is in shape (N, #classes, 4)
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
    else:
        pos_bbox_preds = bbox_pred[pos_label_inds]
    ori_loss_reg = loss_bbox(
        pos_bbox_preds,
        bbox_targets[pos_label_inds],
        reduction_override='none') / avg_factor
    # Broadcast one weight per sample against the loss' trailing dims.
    # A fixed ``[:, None]`` would turn a per-sample (N,) loss into an
    # (N, N) outer product and silently inflate the summed loss.
    sample_weights = carl_loss_weights.reshape(
        (-1, ) + (1, ) * (ori_loss_reg.dim() - 1))
    loss_carl = (ori_loss_reg * sample_weights).sum()
    return dict(loss_carl=loss_carl[None])
|