# trans_max_iou_assigner.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.structures import InstanceData

from mmdet.models.task_modules.assigners.assign_result import AssignResult
from mmdet.models.task_modules.assigners.max_iou_assigner import MaxIoUAssigner
from mmdet.registry import TASK_UTILS
  8. @TASK_UTILS.register_module()
  9. class TransMaxIoUAssigner(MaxIoUAssigner):
  10. def assign(self,
  11. pred_instances: InstanceData,
  12. gt_instances: InstanceData,
  13. gt_instances_ignore: Optional[InstanceData] = None,
  14. **kwargs) -> AssignResult:
  15. """Assign gt to bboxes.
  16. This method assign a gt bbox to every bbox (proposal/anchor), each bbox
  17. will be assigned with -1, or a semi-positive number. -1 means negative
  18. sample, semi-positive number is the index (0-based) of assigned gt.
  19. The assignment is done in following steps, the order matters.
  20. 1. assign every bbox to the background
  21. 2. assign proposals whose iou with all gts < neg_iou_thr to 0
  22. 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
  23. assign it to that bbox
  24. 4. for each gt bbox, assign its nearest proposals (may be more than
  25. one) to itself
  26. Args:
  27. pred_instances (:obj:`InstanceData`): Instances of model
  28. predictions. It includes ``priors``, and the priors can
  29. be anchors or points, or the bboxes predicted by the
  30. previous stage, has shape (n, 4). The bboxes predicted by
  31. the current model or stage will be named ``bboxes``,
  32. ``labels``, and ``scores``, the same as the ``InstanceData``
  33. in other places.
  34. gt_instances (:obj:`InstanceData`): Ground truth of instance
  35. annotations. It usually includes ``bboxes``, with shape (k, 4),
  36. and ``labels``, with shape (k, ).
  37. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  38. to be ignored during training. It includes ``bboxes``
  39. attribute data that is ignored during training and testing.
  40. Defaults to None.
  41. Returns:
  42. :obj:`AssignResult`: The assign result.
  43. Example:
  44. >>> from mmengine.structures import InstanceData
  45. >>> self = MaxIoUAssigner(0.5, 0.5)
  46. >>> pred_instances = InstanceData()
  47. >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
  48. ... [10, 10, 20, 20]])
  49. >>> gt_instances = InstanceData()
  50. >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 9]])
  51. >>> gt_instances.labels = torch.Tensor([0])
  52. >>> assign_result = self.assign(pred_instances, gt_instances)
  53. >>> expected_gt_inds = torch.LongTensor([1, 0])
  54. >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
  55. """
  56. gt_bboxes = gt_instances.bboxes
  57. priors = pred_instances.priors
  58. gt_labels = gt_instances.labels
  59. if gt_instances_ignore is not None:
  60. gt_bboxes_ignore = gt_instances_ignore.bboxes
  61. else:
  62. gt_bboxes_ignore = None
  63. assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
  64. gt_bboxes.shape[0] > self.gpu_assign_thr) else False
  65. # compute overlap and assign gt on CPU when number of GT is large
  66. if assign_on_cpu:
  67. device = priors.device
  68. priors = priors.cpu()
  69. gt_bboxes = gt_bboxes.cpu()
  70. gt_labels = gt_labels.cpu()
  71. if gt_bboxes_ignore is not None:
  72. gt_bboxes_ignore = gt_bboxes_ignore.cpu()
  73. trans_priors = torch.cat([
  74. priors[..., 1].view(-1, 1), priors[..., 0].view(-1, 1),
  75. priors[..., 3].view(-1, 1), priors[..., 2].view(-1, 1)
  76. ],
  77. dim=-1)
  78. overlaps = self.iou_calculator(gt_bboxes, trans_priors)
  79. if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
  80. and gt_bboxes_ignore.numel() > 0 and trans_priors.numel() > 0):
  81. if self.ignore_wrt_candidates:
  82. ignore_overlaps = self.iou_calculator(
  83. trans_priors, gt_bboxes_ignore, mode='iof')
  84. ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
  85. else:
  86. ignore_overlaps = self.iou_calculator(
  87. gt_bboxes_ignore, trans_priors, mode='iof')
  88. ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
  89. overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
  90. assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
  91. if assign_on_cpu:
  92. assign_result.gt_inds = assign_result.gt_inds.to(device)
  93. assign_result.max_overlaps = assign_result.max_overlaps.to(device)
  94. if assign_result.labels is not None:
  95. assign_result.labels = assign_result.labels.to(device)
  96. return assign_result