123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional
- import torch
- from mmengine.structures import InstanceData
- from mmdet.registry import TASK_UTILS
- from .assign_result import AssignResult
- from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point is assigned ``0`` or a positive integer indicating the
    ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int): Divisor applied to the gt width and height before
            taking ``log2`` to pick the pyramid level of a gt box.
            Defaults to 4.
        pos_num (int): Maximum number of nearest points (within the
            selected level) that may be assigned to one gt box.
            Defaults to 3.
    """

    def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
        self.scale = scale
        self.pos_num = pos_num

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to points.

        Every point gets a gt index of ``0`` (background) or a positive
        1-based index of the assigned gt box. The matching label is ``-1``
        for background points and ``gt_labels[gt_index - 1]`` otherwise.

        The assignment is done in the following steps, the order matters.

        1. Assign every point to the background (gt index 0).
        2. For each gt box, pick the pyramid level from the box size and
           assign a point of that level to the gt box if
           (i) the point is within the ``pos_num`` closest points to the
           gt center on that level, and
           (ii) its normalized distance to this gt is strictly smaller
           than the distance recorded for any previously processed gt.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read; columns 0 and 1 are
                used as the point location (x, y) and column 2 as the
                stride of the point.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` is read as two corner points
                (columns ``[:2]`` and ``[2:]``), presumably
                (x1, y1, x2, y2) with shape (k, 4); ``labels`` is indexed
                by gt index, presumably shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Accepted
                for interface compatibility; not used by this assigner.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result, with
            ``max_overlaps`` set to None.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        # points to be assigned; the last used column is the stride,
        # i.e. each row is read as (x, y, stride).
        points = pred_instances.priors
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # Nothing to match: everything is background.
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            assigned_labels = points.new_full((num_points, ),
                                              -1,
                                              dtype=torch.long)
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        # pyramid level of every point, e.g. [3...,4...,5...,6...,7...]
        points_lvl = torch.log2(points_stride).int()
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # center, size and target level of every gt box
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # clamp avoids log2(0) / division by zero for degenerate boxes
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Created on the same device as the points so that it can be
        # indexed by masks / indices derived from them.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # get the index of points in this level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            if points_index.numel() == 0:
                # The clamped level lies between lvl_min and lvl_max but
                # may still hold no points; nothing to assign then.
                continue
            # get the points in this level
            lvl_points = points_xy[lvl_idx, :]
            # get the center point of gt
            gt_point = gt_bboxes_xy[[idx], :]
            # get width and height of gt
            gt_wh = gt_bboxes_wh[[idx], :]
            # compute the distance between gt center and all points in
            # this level, normalized by the gt size
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # find the nearest k points to gt center in this level;
            # topk raises if k exceeds the number of candidates, so cap it
            num_candidates = min(self.pos_num, points_gt_dist.numel())
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            # the index of nearest k points to gt center in this level
            min_dist_points_index = points_index[min_dist_index]
            # The less_than_recorded_index stores the index
            # of min_dist that is less than the assigned_gt_dist. Where
            # assigned_gt_dist stores the dist from previous assigned gt
            # (if exist) to each point.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            # The min_dist_points_index stores the index of points satisfy:
            # (1) it is k nearest to current gt center in this level.
            # (2) it is closer to current gt center than other gt center.
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            # assign the result
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
        # squeeze(1) keeps the index 1-D even with a single positive point
        pos_inds = torch.nonzero(
            assigned_gt_inds > 0, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
                                                  1]

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)
|