123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional
- import torch
- from mmengine.structures import InstanceData
- from mmdet.models.task_modules.assigners.assign_result import AssignResult
- from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner
- from mmdet.registry import TASK_UTILS
@TASK_UTILS.register_module()
class TransMaxIoUAssigner(MaxIoUAssigner):
    """Max-IoU assigner that matches on column-reordered priors.

    The only difference from the parent ``assign`` visible here is that the
    priors are rebuilt with their columns in the order ``(1, 0, 3, 2)``
    before any overlap is computed.
    NOTE(review): for ``(x1, y1, x2, y2)`` boxes this amounts to swapping the
    x and y axes of each prior -- confirm against the caller's box layout.
    """

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign ground-truth boxes to column-reordered priors.

        Processing order:

        1. Read ``priors`` from ``pred_instances`` and ``bboxes``/``labels``
           from ``gt_instances`` (plus ``bboxes`` of the ignore instances,
           when given).
        2. If ``self.gpu_assign_thr`` is positive and the number of gt boxes
           exceeds it, move all inputs to CPU for the computation.
        3. Rebuild the priors with columns ordered ``(1, 0, 3, 2)`` and
           compute ``self.iou_calculator(gt_bboxes, reordered_priors)``.
        4. If ignore handling is enabled (``self.ignore_iof_thr > 0`` and
           non-empty ignore boxes and priors), set the overlap column of
           every prior whose maximum IoF with the ignore boxes exceeds
           ``self.ignore_iof_thr`` to ``-1``.
        5. Delegate the actual assignment to the inherited
           ``assign_wrt_overlaps`` and, if step 2 applied, move the result
           tensors back to the priors' original device.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Must provide ``priors``; columns 0-3 of its
                last dimension are read.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. Must provide ``bboxes`` and ``labels``.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored; only its ``bboxes`` attribute is used.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result.

        Example:
            >>> from mmengine.structures import InstanceData
            >>> self = TransMaxIoUAssigner(0.5, 0.5)
            >>> pred_instances = InstanceData()
            >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
            ...                                       [10, 10, 20, 20]])
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
            >>> gt_instances.labels = torch.Tensor([0])
            >>> assign_result = self.assign(pred_instances, gt_instances)
            >>> expected_gt_inds = torch.LongTensor([1, 0])
            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
        """
        boxes_gt = gt_instances.bboxes
        anchors = pred_instances.priors
        labels_gt = gt_instances.labels
        boxes_ignore = (None if gt_instances_ignore is None else
                        gt_instances_ignore.bboxes)

        # Fall back to CPU when the number of gt boxes is above the
        # configured threshold (a non-positive threshold disables this).
        run_on_cpu = bool(self.gpu_assign_thr > 0
                          and boxes_gt.shape[0] > self.gpu_assign_thr)
        if run_on_cpu:
            # Remember where the priors lived so results can be moved back.
            origin_device = anchors.device
            anchors = anchors.cpu()
            boxes_gt = boxes_gt.cpu()
            labels_gt = labels_gt.cpu()
            if boxes_ignore is not None:
                boxes_ignore = boxes_ignore.cpu()

        # Reorder the prior columns to (1, 0, 3, 2); each column is
        # flattened to (-1, 1) so the result is always 2-D with 4 columns.
        swapped_priors = torch.cat(
            [anchors[..., col].view(-1, 1) for col in (1, 0, 3, 2)], dim=-1)

        iou_matrix = self.iou_calculator(boxes_gt, swapped_priors)

        use_ignore = (self.ignore_iof_thr > 0 and boxes_ignore is not None
                      and boxes_ignore.numel() > 0
                      and swapped_priors.numel() > 0)
        if use_ignore:
            # The argument order decides which set the IoF is taken with
            # respect to; the reduction always yields one value per prior.
            if self.ignore_wrt_candidates:
                iof = self.iou_calculator(
                    swapped_priors, boxes_ignore, mode='iof')
                prior_axis_reduce = 1
            else:
                iof = self.iou_calculator(
                    boxes_ignore, swapped_priors, mode='iof')
                prior_axis_reduce = 0
            max_iof, _ = iof.max(dim=prior_axis_reduce)
            # Mark the overlap columns of priors that hit an ignore region.
            iou_matrix[:, max_iof > self.ignore_iof_thr] = -1

        result = self.assign_wrt_overlaps(iou_matrix, labels_gt)

        if run_on_cpu:
            result.gt_inds = result.gt_inds.to(origin_device)
            result.max_overlaps = result.max_overlaps.to(origin_device)
            if result.labels is not None:
                result.labels = result.labels.to(origin_device)
        return result
|