atss_assigner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. from typing import List, Optional
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.utils import ConfigType
  9. from .assign_result import AssignResult
  10. from .base_assigner import BaseAssigner
  11. def bbox_center_distance(bboxes: Tensor, priors: Tensor) -> Tensor:
  12. """Compute the center distance between bboxes and priors.
  13. Args:
  14. bboxes (Tensor): Shape (n, 4) for , "xyxy" format.
  15. priors (Tensor): Shape (n, 4) for priors, "xyxy" format.
  16. Returns:
  17. Tensor: Center distances between bboxes and priors.
  18. """
  19. bbox_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
  20. bbox_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
  21. bbox_points = torch.stack((bbox_cx, bbox_cy), dim=1)
  22. priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
  23. priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
  24. priors_points = torch.stack((priors_cx, priors_cy), dim=1)
  25. distances = (priors_points[:, None, :] -
  26. bbox_points[None, :, :]).pow(2).sum(-1).sqrt()
  27. return distances
  28. @TASK_UTILS.register_module()
  29. class ATSSAssigner(BaseAssigner):
  30. """Assign a corresponding gt bbox or background to each prior.
  31. Each proposals will be assigned with `0` or a positive integer
  32. indicating the ground truth index.
  33. - 0: negative sample, no assigned gt
  34. - positive integer: positive sample, index (1-based) of assigned gt
  35. If ``alpha`` is not None, it means that the dynamic cost
  36. ATSSAssigner is adopted, which is currently only used in the DDOD.
  37. Args:
  38. topk (int): number of priors selected in each level
  39. alpha (float, optional): param of cost rate for each proposal only
  40. in DDOD. Defaults to None.
  41. iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
  42. calculator. Defaults to ``dict(type='BboxOverlaps2D')``
  43. ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
  44. `gt_bboxes_ignore` is specified). Negative values mean not
  45. ignoring any bboxes. Defaults to -1.
  46. """
  47. def __init__(self,
  48. topk: int,
  49. alpha: Optional[float] = None,
  50. iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
  51. ignore_iof_thr: float = -1) -> None:
  52. self.topk = topk
  53. self.alpha = alpha
  54. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  55. self.ignore_iof_thr = ignore_iof_thr
  56. # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
  57. def assign(
  58. self,
  59. pred_instances: InstanceData,
  60. num_level_priors: List[int],
  61. gt_instances: InstanceData,
  62. gt_instances_ignore: Optional[InstanceData] = None
  63. ) -> AssignResult:
  64. """Assign gt to priors.
  65. The assignment is done in following steps
  66. 1. compute iou between all prior (prior of all pyramid levels) and gt
  67. 2. compute center distance between all prior and gt
  68. 3. on each pyramid level, for each gt, select k prior whose center
  69. are closest to the gt center, so we total select k*l prior as
  70. candidates for each gt
  71. 4. get corresponding iou for the these candidates, and compute the
  72. mean and std, set mean + std as the iou threshold
  73. 5. select these candidates whose iou are greater than or equal to
  74. the threshold as positive
  75. 6. limit the positive sample's center in gt
  76. If ``alpha`` is not None, and ``cls_scores`` and `bbox_preds`
  77. are not None, the overlaps calculation in the first step
  78. will also include dynamic cost, which is currently only used in
  79. the DDOD.
  80. Args:
  81. pred_instances (:obj:`InstaceData`): Instances of model
  82. predictions. It includes ``priors``, and the priors can
  83. be anchors, points, or bboxes predicted by the model,
  84. shape(n, 4).
  85. num_level_priors (List): Number of bboxes in each level
  86. gt_instances (:obj:`InstaceData`): Ground truth of instance
  87. annotations. It usually includes ``bboxes`` and ``labels``
  88. attributes.
  89. gt_instances_ignore (:obj:`InstaceData`, optional): Instances
  90. to be ignored during training. It includes ``bboxes``
  91. attribute data that is ignored during training and testing.
  92. Defaults to None.
  93. Returns:
  94. :obj:`AssignResult`: The assign result.
  95. """
  96. gt_bboxes = gt_instances.bboxes
  97. priors = pred_instances.priors
  98. gt_labels = gt_instances.labels
  99. if gt_instances_ignore is not None:
  100. gt_bboxes_ignore = gt_instances_ignore.bboxes
  101. else:
  102. gt_bboxes_ignore = None
  103. INF = 100000000
  104. priors = priors[:, :4]
  105. num_gt, num_priors = gt_bboxes.size(0), priors.size(0)
  106. message = 'Invalid alpha parameter because cls_scores or ' \
  107. 'bbox_preds are None. If you want to use the ' \
  108. 'cost-based ATSSAssigner, please set cls_scores, ' \
  109. 'bbox_preds and self.alpha at the same time. '
  110. # compute iou between all bbox and gt
  111. if self.alpha is None:
  112. # ATSSAssigner
  113. overlaps = self.iou_calculator(priors, gt_bboxes)
  114. if ('scores' in pred_instances or 'bboxes' in pred_instances):
  115. warnings.warn(message)
  116. else:
  117. # Dynamic cost ATSSAssigner in DDOD
  118. assert ('scores' in pred_instances
  119. and 'bboxes' in pred_instances), message
  120. cls_scores = pred_instances.scores
  121. bbox_preds = pred_instances.bboxes
  122. # compute cls cost for bbox and GT
  123. cls_cost = torch.sigmoid(cls_scores[:, gt_labels])
  124. # compute iou between all bbox and gt
  125. overlaps = self.iou_calculator(bbox_preds, gt_bboxes)
  126. # make sure that we are in element-wise multiplication
  127. assert cls_cost.shape == overlaps.shape
  128. # overlaps is actually a cost matrix
  129. overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha
  130. # assign 0 by default
  131. assigned_gt_inds = overlaps.new_full((num_priors, ),
  132. 0,
  133. dtype=torch.long)
  134. if num_gt == 0 or num_priors == 0:
  135. # No ground truth or boxes, return empty assignment
  136. max_overlaps = overlaps.new_zeros((num_priors, ))
  137. if num_gt == 0:
  138. # No truth, assign everything to background
  139. assigned_gt_inds[:] = 0
  140. assigned_labels = overlaps.new_full((num_priors, ),
  141. -1,
  142. dtype=torch.long)
  143. return AssignResult(
  144. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
  145. # compute center distance between all bbox and gt
  146. distances = bbox_center_distance(gt_bboxes, priors)
  147. if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
  148. and gt_bboxes_ignore.numel() > 0 and priors.numel() > 0):
  149. ignore_overlaps = self.iou_calculator(
  150. priors, gt_bboxes_ignore, mode='iof')
  151. ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
  152. ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
  153. distances[ignore_idxs, :] = INF
  154. assigned_gt_inds[ignore_idxs] = -1
  155. # Selecting candidates based on the center distance
  156. candidate_idxs = []
  157. start_idx = 0
  158. for level, priors_per_level in enumerate(num_level_priors):
  159. # on each pyramid level, for each gt,
  160. # select k bbox whose center are closest to the gt center
  161. end_idx = start_idx + priors_per_level
  162. distances_per_level = distances[start_idx:end_idx, :]
  163. selectable_k = min(self.topk, priors_per_level)
  164. _, topk_idxs_per_level = distances_per_level.topk(
  165. selectable_k, dim=0, largest=False)
  166. candidate_idxs.append(topk_idxs_per_level + start_idx)
  167. start_idx = end_idx
  168. candidate_idxs = torch.cat(candidate_idxs, dim=0)
  169. # get corresponding iou for the these candidates, and compute the
  170. # mean and std, set mean + std as the iou threshold
  171. candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
  172. overlaps_mean_per_gt = candidate_overlaps.mean(0)
  173. overlaps_std_per_gt = candidate_overlaps.std(0)
  174. overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
  175. is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
  176. # limit the positive sample's center in gt
  177. for gt_idx in range(num_gt):
  178. candidate_idxs[:, gt_idx] += gt_idx * num_priors
  179. priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
  180. priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
  181. ep_priors_cx = priors_cx.view(1, -1).expand(
  182. num_gt, num_priors).contiguous().view(-1)
  183. ep_priors_cy = priors_cy.view(1, -1).expand(
  184. num_gt, num_priors).contiguous().view(-1)
  185. candidate_idxs = candidate_idxs.view(-1)
  186. # calculate the left, top, right, bottom distance between positive
  187. # prior center and gt side
  188. l_ = ep_priors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
  189. t_ = ep_priors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
  190. r_ = gt_bboxes[:, 2] - ep_priors_cx[candidate_idxs].view(-1, num_gt)
  191. b_ = gt_bboxes[:, 3] - ep_priors_cy[candidate_idxs].view(-1, num_gt)
  192. is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
  193. is_pos = is_pos & is_in_gts
  194. # if an anchor box is assigned to multiple gts,
  195. # the one with the highest IoU will be selected.
  196. overlaps_inf = torch.full_like(overlaps,
  197. -INF).t().contiguous().view(-1)
  198. index = candidate_idxs.view(-1)[is_pos.view(-1)]
  199. overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
  200. overlaps_inf = overlaps_inf.view(num_gt, -1).t()
  201. max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
  202. assigned_gt_inds[
  203. max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
  204. assigned_labels = assigned_gt_inds.new_full((num_priors, ), -1)
  205. pos_inds = torch.nonzero(
  206. assigned_gt_inds > 0, as_tuple=False).squeeze()
  207. if pos_inds.numel() > 0:
  208. assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
  209. 1]
  210. return AssignResult(
  211. num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)