point_assigner.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import TASK_UTILS
  6. from .assign_result import AssignResult
  7. from .base_assigner import BaseAssigner
  8. @TASK_UTILS.register_module()
  9. class PointAssigner(BaseAssigner):
  10. """Assign a corresponding gt bbox or background to each point.
  11. Each proposals will be assigned with `0`, or a positive integer
  12. indicating the ground truth index.
  13. - 0: negative sample, no assigned gt
  14. - positive integer: positive sample, index (1-based) of assigned gt
  15. """
  16. def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
  17. self.scale = scale
  18. self.pos_num = pos_num
  19. def assign(self,
  20. pred_instances: InstanceData,
  21. gt_instances: InstanceData,
  22. gt_instances_ignore: Optional[InstanceData] = None,
  23. **kwargs) -> AssignResult:
  24. """Assign gt to points.
  25. This method assign a gt bbox to every points set, each points set
  26. will be assigned with the background_label (-1), or a label number.
  27. -1 is background, and semi-positive number is the index (0-based) of
  28. assigned gt.
  29. The assignment is done in following steps, the order matters.
  30. 1. assign every points to the background_label (-1)
  31. 2. A point is assigned to some gt bbox if
  32. (i) the point is within the k closest points to the gt bbox
  33. (ii) the distance between this point and the gt is smaller than
  34. other gt bboxes
  35. Args:
  36. pred_instances (:obj:`InstanceData`): Instances of model
  37. predictions. It includes ``priors``, and the priors can
  38. be anchors or points, or the bboxes predicted by the
  39. previous stage, has shape (n, 4). The bboxes predicted by
  40. the current model or stage will be named ``bboxes``,
  41. ``labels``, and ``scores``, the same as the ``InstanceData``
  42. in other places.
  43. gt_instances (:obj:`InstanceData`): Ground truth of instance
  44. annotations. It usually includes ``bboxes``, with shape (k, 4),
  45. and ``labels``, with shape (k, ).
  46. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  47. to be ignored during training. It includes ``bboxes``
  48. attribute data that is ignored during training and testing.
  49. Defaults to None.
  50. Returns:
  51. :obj:`AssignResult`: The assign result.
  52. """
  53. gt_bboxes = gt_instances.bboxes
  54. gt_labels = gt_instances.labels
  55. # points to be assigned, shape(n, 3) while last
  56. # dimension stands for (x, y, stride).
  57. points = pred_instances.priors
  58. num_points = points.shape[0]
  59. num_gts = gt_bboxes.shape[0]
  60. if num_gts == 0 or num_points == 0:
  61. # If no truth assign everything to the background
  62. assigned_gt_inds = points.new_full((num_points, ),
  63. 0,
  64. dtype=torch.long)
  65. assigned_labels = points.new_full((num_points, ),
  66. -1,
  67. dtype=torch.long)
  68. return AssignResult(
  69. num_gts=num_gts,
  70. gt_inds=assigned_gt_inds,
  71. max_overlaps=None,
  72. labels=assigned_labels)
  73. points_xy = points[:, :2]
  74. points_stride = points[:, 2]
  75. points_lvl = torch.log2(
  76. points_stride).int() # [3...,4...,5...,6...,7...]
  77. lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
  78. # assign gt box
  79. gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
  80. gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
  81. scale = self.scale
  82. gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
  83. torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
  84. gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
  85. # stores the assigned gt index of each point
  86. assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
  87. # stores the assigned gt dist (to this point) of each point
  88. assigned_gt_dist = points.new_full((num_points, ), float('inf'))
  89. points_range = torch.arange(points.shape[0])
  90. for idx in range(num_gts):
  91. gt_lvl = gt_bboxes_lvl[idx]
  92. # get the index of points in this level
  93. lvl_idx = gt_lvl == points_lvl
  94. points_index = points_range[lvl_idx]
  95. # get the points in this level
  96. lvl_points = points_xy[lvl_idx, :]
  97. # get the center point of gt
  98. gt_point = gt_bboxes_xy[[idx], :]
  99. # get width and height of gt
  100. gt_wh = gt_bboxes_wh[[idx], :]
  101. # compute the distance between gt center and
  102. # all points in this level
  103. points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
  104. # find the nearest k points to gt center in this level
  105. min_dist, min_dist_index = torch.topk(
  106. points_gt_dist, self.pos_num, largest=False)
  107. # the index of nearest k points to gt center in this level
  108. min_dist_points_index = points_index[min_dist_index]
  109. # The less_than_recorded_index stores the index
  110. # of min_dist that is less then the assigned_gt_dist. Where
  111. # assigned_gt_dist stores the dist from previous assigned gt
  112. # (if exist) to each point.
  113. less_than_recorded_index = min_dist < assigned_gt_dist[
  114. min_dist_points_index]
  115. # The min_dist_points_index stores the index of points satisfy:
  116. # (1) it is k nearest to current gt center in this level.
  117. # (2) it is closer to current gt center than other gt center.
  118. min_dist_points_index = min_dist_points_index[
  119. less_than_recorded_index]
  120. # assign the result
  121. assigned_gt_inds[min_dist_points_index] = idx + 1
  122. assigned_gt_dist[min_dist_points_index] = min_dist[
  123. less_than_recorded_index]
  124. assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
  125. pos_inds = torch.nonzero(
  126. assigned_gt_inds > 0, as_tuple=False).squeeze()
  127. if pos_inds.numel() > 0:
  128. assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
  129. 1]
  130. return AssignResult(
  131. num_gts=num_gts,
  132. gt_inds=assigned_gt_inds,
  133. max_overlaps=None,
  134. labels=assigned_labels)