# dynamic_soft_label_assigner.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.structures.bbox import BaseBoxes
  9. from mmdet.utils import ConfigType
  10. from .assign_result import AssignResult
  11. from .base_assigner import BaseAssigner
  12. INF = 100000000
  13. EPS = 1.0e-7
  14. def center_of_mass(masks: Tensor, eps: float = 1e-7) -> Tensor:
  15. """Compute the masks center of mass.
  16. Args:
  17. masks: Mask tensor, has shape (num_masks, H, W).
  18. eps: a small number to avoid normalizer to be zero.
  19. Defaults to 1e-7.
  20. Returns:
  21. Tensor: The masks center of mass. Has shape (num_masks, 2).
  22. """
  23. n, h, w = masks.shape
  24. grid_h = torch.arange(h, device=masks.device)[:, None]
  25. grid_w = torch.arange(w, device=masks.device)
  26. normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps)
  27. center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer
  28. center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer
  29. center = torch.cat([center_x[:, None], center_y[:, None]], dim=1)
  30. return center
  31. @TASK_UTILS.register_module()
  32. class DynamicSoftLabelAssigner(BaseAssigner):
  33. """Computes matching between predictions and ground truth with dynamic soft
  34. label assignment.
  35. Args:
  36. soft_center_radius (float): Radius of the soft center prior.
  37. Defaults to 3.0.
  38. topk (int): Select top-k predictions to calculate dynamic k
  39. best matches for each gt. Defaults to 13.
  40. iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
  41. iou_calculator (ConfigType): Config of overlaps Calculator.
  42. Defaults to dict(type='BboxOverlaps2D').
  43. """
  44. def __init__(
  45. self,
  46. soft_center_radius: float = 3.0,
  47. topk: int = 13,
  48. iou_weight: float = 3.0,
  49. iou_calculator: ConfigType = dict(type='BboxOverlaps2D')
  50. ) -> None:
  51. self.soft_center_radius = soft_center_radius
  52. self.topk = topk
  53. self.iou_weight = iou_weight
  54. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  55. def assign(self,
  56. pred_instances: InstanceData,
  57. gt_instances: InstanceData,
  58. gt_instances_ignore: Optional[InstanceData] = None,
  59. **kwargs) -> AssignResult:
  60. """Assign gt to priors.
  61. Args:
  62. pred_instances (:obj:`InstanceData`): Instances of model
  63. predictions. It includes ``priors``, and the priors can
  64. be anchors or points, or the bboxes predicted by the
  65. previous stage, has shape (n, 4). The bboxes predicted by
  66. the current model or stage will be named ``bboxes``,
  67. ``labels``, and ``scores``, the same as the ``InstanceData``
  68. in other places.
  69. gt_instances (:obj:`InstanceData`): Ground truth of instance
  70. annotations. It usually includes ``bboxes``, with shape (k, 4),
  71. and ``labels``, with shape (k, ).
  72. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  73. to be ignored during training. It includes ``bboxes``
  74. attribute data that is ignored during training and testing.
  75. Defaults to None.
  76. Returns:
  77. obj:`AssignResult`: The assigned result.
  78. """
  79. gt_bboxes = gt_instances.bboxes
  80. gt_labels = gt_instances.labels
  81. num_gt = gt_bboxes.size(0)
  82. decoded_bboxes = pred_instances.bboxes
  83. pred_scores = pred_instances.scores
  84. priors = pred_instances.priors
  85. num_bboxes = decoded_bboxes.size(0)
  86. # assign 0 by default
  87. assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
  88. 0,
  89. dtype=torch.long)
  90. if num_gt == 0 or num_bboxes == 0:
  91. # No ground truth or boxes, return empty assignment
  92. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  93. if num_gt == 0:
  94. # No truth, assign everything to background
  95. assigned_gt_inds[:] = 0
  96. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  97. -1,
  98. dtype=torch.long)
  99. return AssignResult(
  100. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  101. prior_center = priors[:, :2]
  102. if isinstance(gt_bboxes, BaseBoxes):
  103. is_in_gts = gt_bboxes.find_inside_points(prior_center)
  104. else:
  105. # Tensor boxes will be treated as horizontal boxes by defaults
  106. lt_ = prior_center[:, None] - gt_bboxes[:, :2]
  107. rb_ = gt_bboxes[:, 2:] - prior_center[:, None]
  108. deltas = torch.cat([lt_, rb_], dim=-1)
  109. is_in_gts = deltas.min(dim=-1).values > 0
  110. valid_mask = is_in_gts.sum(dim=1) > 0
  111. valid_decoded_bbox = decoded_bboxes[valid_mask]
  112. valid_pred_scores = pred_scores[valid_mask]
  113. num_valid = valid_decoded_bbox.size(0)
  114. if num_valid == 0:
  115. # No ground truth or boxes, return empty assignment
  116. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  117. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  118. -1,
  119. dtype=torch.long)
  120. return AssignResult(
  121. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  122. if hasattr(gt_instances, 'masks'):
  123. gt_center = center_of_mass(gt_instances.masks, eps=EPS)
  124. elif isinstance(gt_bboxes, BaseBoxes):
  125. gt_center = gt_bboxes.centers
  126. else:
  127. # Tensor boxes will be treated as horizontal boxes by defaults
  128. gt_center = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2.0
  129. valid_prior = priors[valid_mask]
  130. strides = valid_prior[:, 2]
  131. distance = (valid_prior[:, None, :2] - gt_center[None, :, :]
  132. ).pow(2).sum(-1).sqrt() / strides[:, None]
  133. soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
  134. pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
  135. iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight
  136. gt_onehot_label = (
  137. F.one_hot(gt_labels.to(torch.int64),
  138. pred_scores.shape[-1]).float().unsqueeze(0).repeat(
  139. num_valid, 1, 1))
  140. valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
  141. soft_label = gt_onehot_label * pairwise_ious[..., None]
  142. scale_factor = soft_label - valid_pred_scores.sigmoid()
  143. soft_cls_cost = F.binary_cross_entropy_with_logits(
  144. valid_pred_scores, soft_label,
  145. reduction='none') * scale_factor.abs().pow(2.0)
  146. soft_cls_cost = soft_cls_cost.sum(dim=-1)
  147. cost_matrix = soft_cls_cost + iou_cost + soft_center_prior
  148. matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
  149. cost_matrix, pairwise_ious, num_gt, valid_mask)
  150. # convert to AssignResult format
  151. assigned_gt_inds[valid_mask] = matched_gt_inds + 1
  152. assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
  153. assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
  154. max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
  155. -INF,
  156. dtype=torch.float32)
  157. max_overlaps[valid_mask] = matched_pred_ious
  158. return AssignResult(
  159. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  160. def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
  161. num_gt: int,
  162. valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
  163. """Use IoU and matching cost to calculate the dynamic top-k positive
  164. targets. Same as SimOTA.
  165. Args:
  166. cost (Tensor): Cost matrix.
  167. pairwise_ious (Tensor): Pairwise iou matrix.
  168. num_gt (int): Number of gt.
  169. valid_mask (Tensor): Mask for valid bboxes.
  170. Returns:
  171. tuple: matched ious and gt indexes.
  172. """
  173. matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
  174. # select candidate topk ious for dynamic-k calculation
  175. candidate_topk = min(self.topk, pairwise_ious.size(0))
  176. topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
  177. # calculate dynamic k for each gt
  178. dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
  179. for gt_idx in range(num_gt):
  180. _, pos_idx = torch.topk(
  181. cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
  182. matching_matrix[:, gt_idx][pos_idx] = 1
  183. del topk_ious, dynamic_ks, pos_idx
  184. prior_match_gt_mask = matching_matrix.sum(1) > 1
  185. if prior_match_gt_mask.sum() > 0:
  186. cost_min, cost_argmin = torch.min(
  187. cost[prior_match_gt_mask, :], dim=1)
  188. matching_matrix[prior_match_gt_mask, :] *= 0
  189. matching_matrix[prior_match_gt_mask, cost_argmin] = 1
  190. # get foreground mask inside box and center prior
  191. fg_mask_inboxes = matching_matrix.sum(1) > 0
  192. valid_mask[valid_mask.clone()] = fg_mask_inboxes
  193. matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
  194. matched_pred_ious = (matching_matrix *
  195. pairwise_ious).sum(1)[fg_mask_inboxes]
  196. return matched_pred_ious, matched_gt_inds