assigner.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmdet.models.task_modules.assigners import AssignResult, SimOTAAssigner
  6. from mmdet.utils import ConfigType
  7. from mmengine.structures import InstanceData
  8. from mmyolo.registry import MODELS, TASK_UTILS
  9. from torch import Tensor
  10. INF = 100000.0
  11. EPS = 1.0e-7
  12. @TASK_UTILS.register_module()
  13. class PoseSimOTAAssigner(SimOTAAssigner):
  14. def __init__(self,
  15. center_radius: float = 2.5,
  16. candidate_topk: int = 10,
  17. iou_weight: float = 3.0,
  18. cls_weight: float = 1.0,
  19. oks_weight: float = 0.0,
  20. vis_weight: float = 0.0,
  21. iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
  22. oks_calculator: ConfigType = dict(type='OksLoss')):
  23. self.center_radius = center_radius
  24. self.candidate_topk = candidate_topk
  25. self.iou_weight = iou_weight
  26. self.cls_weight = cls_weight
  27. self.oks_weight = oks_weight
  28. self.vis_weight = vis_weight
  29. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  30. self.oks_calculator = MODELS.build(oks_calculator)
  31. def assign(self,
  32. pred_instances: InstanceData,
  33. gt_instances: InstanceData,
  34. gt_instances_ignore: Optional[InstanceData] = None,
  35. **kwargs) -> AssignResult:
  36. """Assign gt to priors using SimOTA.
  37. Args:
  38. pred_instances (:obj:`InstanceData`): Instances of model
  39. predictions. It includes ``priors``, and the priors can
  40. be anchors or points, or the bboxes predicted by the
  41. previous stage, has shape (n, 4). The bboxes predicted by
  42. the current model or stage will be named ``bboxes``,
  43. ``labels``, and ``scores``, the same as the ``InstanceData``
  44. in other places.
  45. gt_instances (:obj:`InstanceData`): Ground truth of instance
  46. annotations. It usually includes ``bboxes``, with shape (k, 4),
  47. and ``labels``, with shape (k, ).
  48. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  49. to be ignored during training. It includes ``bboxes``
  50. attribute data that is ignored during training and testing.
  51. Defaults to None.
  52. Returns:
  53. obj:`AssignResult`: The assigned result.
  54. """
  55. gt_bboxes = gt_instances.bboxes
  56. gt_labels = gt_instances.labels
  57. gt_keypoints = gt_instances.keypoints
  58. gt_keypoints_visible = gt_instances.keypoints_visible
  59. num_gt = gt_bboxes.size(0)
  60. decoded_bboxes = pred_instances.bboxes[..., :4]
  61. pred_kpts = pred_instances.bboxes[..., 4:]
  62. pred_kpts = pred_kpts.reshape(*pred_kpts.shape[:-1], -1, 3)
  63. pred_kpts_vis = pred_kpts[..., -1]
  64. pred_kpts = pred_kpts[..., :2]
  65. pred_scores = pred_instances.scores
  66. priors = pred_instances.priors
  67. num_bboxes = decoded_bboxes.size(0)
  68. # assign 0 by default
  69. assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
  70. 0,
  71. dtype=torch.long)
  72. if num_gt == 0 or num_bboxes == 0:
  73. # No ground truth or boxes, return empty assignment
  74. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  75. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  76. -1,
  77. dtype=torch.long)
  78. return AssignResult(
  79. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  80. valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
  81. priors, gt_bboxes)
  82. valid_decoded_bbox = decoded_bboxes[valid_mask]
  83. valid_pred_scores = pred_scores[valid_mask]
  84. valid_pred_kpts = pred_kpts[valid_mask]
  85. valid_pred_kpts_vis = pred_kpts_vis[valid_mask]
  86. num_valid = valid_decoded_bbox.size(0)
  87. if num_valid == 0:
  88. # No valid bboxes, return empty assignment
  89. max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
  90. assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
  91. -1,
  92. dtype=torch.long)
  93. return AssignResult(
  94. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  95. cost_matrix = (~is_in_boxes_and_center) * INF
  96. # calculate iou
  97. pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
  98. if self.iou_weight > 0:
  99. iou_cost = -torch.log(pairwise_ious + EPS)
  100. cost_matrix = cost_matrix + iou_cost * self.iou_weight
  101. # calculate oks
  102. pairwise_oks = self.oks_calculator.compute_oks(
  103. valid_pred_kpts.unsqueeze(1), # [num_valid, -1, k, 2]
  104. gt_keypoints.unsqueeze(0), # [1, num_gt, k, 2]
  105. gt_keypoints_visible.unsqueeze(0), # [1, num_gt, k]
  106. bboxes=gt_bboxes.unsqueeze(0), # [1, num_gt, 4]
  107. ) # -> [num_valid, num_gt]
  108. if self.oks_weight > 0:
  109. oks_cost = -torch.log(pairwise_oks + EPS)
  110. cost_matrix = cost_matrix + oks_cost * self.oks_weight
  111. # calculate cls
  112. if self.cls_weight > 0:
  113. gt_onehot_label = (
  114. F.one_hot(gt_labels.to(torch.int64),
  115. pred_scores.shape[-1]).float().unsqueeze(0).repeat(
  116. num_valid, 1, 1))
  117. valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(
  118. 1, num_gt, 1)
  119. # disable AMP autocast to avoid overflow
  120. with torch.cuda.amp.autocast(enabled=False):
  121. cls_cost = (
  122. F.binary_cross_entropy(
  123. valid_pred_scores.to(dtype=torch.float32),
  124. gt_onehot_label,
  125. reduction='none',
  126. ).sum(-1).to(dtype=valid_pred_scores.dtype))
  127. cost_matrix = cost_matrix + cls_cost * self.cls_weight
  128. # calculate vis
  129. if self.vis_weight > 0:
  130. valid_pred_kpts_vis = valid_pred_kpts_vis.sigmoid().unsqueeze(
  131. 1).repeat(1, num_gt, 1) # [num_valid, 1, k]
  132. gt_kpt_vis = gt_keypoints_visible.unsqueeze(
  133. 0).float() # [1, num_gt, k]
  134. with torch.cuda.amp.autocast(enabled=False):
  135. vis_cost = (
  136. F.binary_cross_entropy(
  137. valid_pred_kpts_vis.to(dtype=torch.float32),
  138. gt_kpt_vis.repeat(num_valid, 1, 1),
  139. reduction='none',
  140. ).sum(-1).to(dtype=valid_pred_kpts_vis.dtype))
  141. cost_matrix = cost_matrix + vis_cost * self.vis_weight
  142. # mixed metric
  143. pairwise_oks = pairwise_oks.pow(0.5)
  144. matched_pred_oks, matched_gt_inds = \
  145. self.dynamic_k_matching(
  146. cost_matrix, pairwise_ious, pairwise_oks, num_gt, valid_mask)
  147. # convert to AssignResult format
  148. assigned_gt_inds[valid_mask] = matched_gt_inds + 1
  149. assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
  150. assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
  151. max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
  152. -INF,
  153. dtype=torch.float32)
  154. max_overlaps[valid_mask] = matched_pred_oks
  155. return AssignResult(
  156. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  157. def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
  158. pairwise_oks: Tensor, num_gt: int,
  159. valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
  160. """Use IoU and matching cost to calculate the dynamic top-k positive
  161. targets."""
  162. matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
  163. # select candidate topk ious for dynamic-k calculation
  164. candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
  165. topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
  166. # calculate dynamic k for each gt
  167. dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
  168. for gt_idx in range(num_gt):
  169. _, pos_idx = torch.topk(
  170. cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
  171. matching_matrix[:, gt_idx][pos_idx] = 1
  172. del topk_ious, dynamic_ks, pos_idx
  173. prior_match_gt_mask = matching_matrix.sum(1) > 1
  174. if prior_match_gt_mask.sum() > 0:
  175. cost_min, cost_argmin = torch.min(
  176. cost[prior_match_gt_mask, :], dim=1)
  177. matching_matrix[prior_match_gt_mask, :] *= 0
  178. matching_matrix[prior_match_gt_mask, cost_argmin] = 1
  179. # get foreground mask inside box and center prior
  180. fg_mask_inboxes = matching_matrix.sum(1) > 0
  181. valid_mask[valid_mask.clone()] = fg_mask_inboxes
  182. matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
  183. matched_pred_oks = (matching_matrix *
  184. pairwise_oks).sum(1)[fg_mask_inboxes]
  185. return matched_pred_oks, matched_gt_inds