pisa_loss.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from mmdet.structures.bbox import bbox_overlaps
  7. from ..task_modules.coders import BaseBBoxCoder
  8. from ..task_modules.samplers import SamplingResult
  9. def isr_p(cls_score: Tensor,
  10. bbox_pred: Tensor,
  11. bbox_targets: Tuple[Tensor],
  12. rois: Tensor,
  13. sampling_results: List[SamplingResult],
  14. loss_cls: nn.Module,
  15. bbox_coder: BaseBBoxCoder,
  16. k: float = 2,
  17. bias: float = 0,
  18. num_class: int = 80) -> tuple:
  19. """Importance-based Sample Reweighting (ISR_P), positive part.
  20. Args:
  21. cls_score (Tensor): Predicted classification scores.
  22. bbox_pred (Tensor): Predicted bbox deltas.
  23. bbox_targets (tuple[Tensor]): A tuple of bbox targets, the are
  24. labels, label_weights, bbox_targets, bbox_weights, respectively.
  25. rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
  26. (two_stage) in shape (n, 5).
  27. sampling_results (:obj:`SamplingResult`): Sampling results.
  28. loss_cls (:obj:`nn.Module`): Classification loss func of the head.
  29. bbox_coder (:obj:`BaseBBoxCoder`): BBox coder of the head.
  30. k (float): Power of the non-linear mapping. Defaults to 2.
  31. bias (float): Shift of the non-linear mapping. Defaults to 0.
  32. num_class (int): Number of classes, defaults to 80.
  33. Return:
  34. tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
  35. bbox_target_weights
  36. """
  37. labels, label_weights, bbox_targets, bbox_weights = bbox_targets
  38. pos_label_inds = ((labels >= 0) &
  39. (labels < num_class)).nonzero().reshape(-1)
  40. pos_labels = labels[pos_label_inds]
  41. # if no positive samples, return the original targets
  42. num_pos = float(pos_label_inds.size(0))
  43. if num_pos == 0:
  44. return labels, label_weights, bbox_targets, bbox_weights
  45. # merge pos_assigned_gt_inds of per image to a single tensor
  46. gts = list()
  47. last_max_gt = 0
  48. for i in range(len(sampling_results)):
  49. gt_i = sampling_results[i].pos_assigned_gt_inds
  50. gts.append(gt_i + last_max_gt)
  51. if len(gt_i) != 0:
  52. last_max_gt = gt_i.max() + 1
  53. gts = torch.cat(gts)
  54. assert len(gts) == num_pos
  55. cls_score = cls_score.detach()
  56. bbox_pred = bbox_pred.detach()
  57. # For single stage detectors, rois here indicate anchors, in shape (N, 4)
  58. # For two stage detectors, rois are in shape (N, 5)
  59. if rois.size(-1) == 5:
  60. pos_rois = rois[pos_label_inds][:, 1:]
  61. else:
  62. pos_rois = rois[pos_label_inds]
  63. if bbox_pred.size(-1) > 4:
  64. bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
  65. pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
  66. else:
  67. pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)
  68. # compute iou of the predicted bbox and the corresponding GT
  69. pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
  70. pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
  71. target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
  72. ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)
  73. pos_imp_weights = label_weights[pos_label_inds]
  74. # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally,
  75. # then sorted again within the same-rank group
  76. max_l_num = pos_labels.bincount().max()
  77. for label in pos_labels.unique():
  78. l_inds = (pos_labels == label).nonzero().view(-1)
  79. l_gts = gts[l_inds]
  80. for t in l_gts.unique():
  81. t_inds = l_inds[l_gts == t]
  82. t_ious = ious[t_inds]
  83. _, t_iou_rank_idx = t_ious.sort(descending=True)
  84. _, t_iou_rank = t_iou_rank_idx.sort()
  85. ious[t_inds] += max_l_num - t_iou_rank.float()
  86. l_ious = ious[l_inds]
  87. _, l_iou_rank_idx = l_ious.sort(descending=True)
  88. _, l_iou_rank = l_iou_rank_idx.sort() # IoU-HLR
  89. # linearly map HLR to label weights
  90. pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num
  91. pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)
  92. # normalize to make the new weighted loss value equal to the original loss
  93. pos_loss_cls = loss_cls(
  94. cls_score[pos_label_inds], pos_labels, reduction_override='none')
  95. if pos_loss_cls.dim() > 1:
  96. ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
  97. None]
  98. new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
  99. else:
  100. ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
  101. new_pos_loss_cls = pos_loss_cls * pos_imp_weights
  102. pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
  103. pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
  104. label_weights[pos_label_inds] = pos_imp_weights
  105. bbox_targets = labels, label_weights, bbox_targets, bbox_weights
  106. return bbox_targets
  107. def carl_loss(cls_score: Tensor,
  108. labels: Tensor,
  109. bbox_pred: Tensor,
  110. bbox_targets: Tensor,
  111. loss_bbox: nn.Module,
  112. k: float = 1,
  113. bias: float = 0.2,
  114. avg_factor: Optional[int] = None,
  115. sigmoid: bool = False,
  116. num_class: int = 80) -> dict:
  117. """Classification-Aware Regression Loss (CARL).
  118. Args:
  119. cls_score (Tensor): Predicted classification scores.
  120. labels (Tensor): Targets of classification.
  121. bbox_pred (Tensor): Predicted bbox deltas.
  122. bbox_targets (Tensor): Target of bbox regression.
  123. loss_bbox (func): Regression loss func of the head.
  124. bbox_coder (obj): BBox coder of the head.
  125. k (float): Power of the non-linear mapping. Defaults to 1.
  126. bias (float): Shift of the non-linear mapping. Defaults to 0.2.
  127. avg_factor (int, optional): Average factor used in regression loss.
  128. sigmoid (bool): Activation of the classification score.
  129. num_class (int): Number of classes, defaults to 80.
  130. Return:
  131. dict: CARL loss dict.
  132. """
  133. pos_label_inds = ((labels >= 0) &
  134. (labels < num_class)).nonzero().reshape(-1)
  135. if pos_label_inds.numel() == 0:
  136. return dict(loss_carl=cls_score.sum()[None] * 0.)
  137. pos_labels = labels[pos_label_inds]
  138. # multiply pos_cls_score with the corresponding bbox weight
  139. # and remain gradient
  140. if sigmoid:
  141. pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
  142. else:
  143. pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
  144. carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)
  145. # normalize carl_loss_weight to make its sum equal to num positive
  146. num_pos = float(pos_cls_score.size(0))
  147. weight_ratio = num_pos / carl_loss_weights.sum()
  148. carl_loss_weights *= weight_ratio
  149. if avg_factor is None:
  150. avg_factor = bbox_targets.size(0)
  151. # if is class agnostic, bbox pred is in shape (N, 4)
  152. # otherwise, bbox pred is in shape (N, #classes, 4)
  153. if bbox_pred.size(-1) > 4:
  154. bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
  155. pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
  156. else:
  157. pos_bbox_preds = bbox_pred[pos_label_inds]
  158. ori_loss_reg = loss_bbox(
  159. pos_bbox_preds,
  160. bbox_targets[pos_label_inds],
  161. reduction_override='none') / avg_factor
  162. loss_carl = (ori_loss_reg * carl_loss_weights[:, None]).sum()
  163. return dict(loss_carl=loss_carl[None])