hungarian_assigner.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Union
  3. import torch
  4. from mmengine import ConfigDict
  5. from mmengine.structures import InstanceData
  6. from scipy.optimize import linear_sum_assignment
  7. from torch import Tensor
  8. from mmdet.registry import TASK_UTILS
  9. from .assign_result import AssignResult
  10. from .base_assigner import BaseAssigner
  11. @TASK_UTILS.register_module()
  12. class HungarianAssigner(BaseAssigner):
  13. """Computes one-to-one matching between predictions and ground truth.
  14. This class computes an assignment between the targets and the predictions
  15. based on the costs. The costs are weighted sum of some components.
  16. For DETR the costs are weighted sum of classification cost, regression L1
  17. cost and regression iou cost. The targets don't include the no_object, so
  18. generally there are more predictions than targets. After the one-to-one
  19. matching, the un-matched are treated as backgrounds. Thus each query
  20. prediction will be assigned with `0` or a positive integer indicating the
  21. ground truth index:
  22. - 0: negative sample, no assigned gt
  23. - positive integer: positive sample, index (1-based) of assigned gt
  24. Args:
  25. match_costs (:obj:`ConfigDict` or dict or \
  26. List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
  27. """
  28. def __init__(
  29. self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
  30. ConfigDict]
  31. ) -> None:
  32. if isinstance(match_costs, dict):
  33. match_costs = [match_costs]
  34. elif isinstance(match_costs, list):
  35. assert len(match_costs) > 0, \
  36. 'match_costs must not be a empty list.'
  37. self.match_costs = [
  38. TASK_UTILS.build(match_cost) for match_cost in match_costs
  39. ]
  40. def assign(self,
  41. pred_instances: InstanceData,
  42. gt_instances: InstanceData,
  43. img_meta: Optional[dict] = None,
  44. **kwargs) -> AssignResult:
  45. """Computes one-to-one matching based on the weighted costs.
  46. This method assign each query prediction to a ground truth or
  47. background. The `assigned_gt_inds` with -1 means don't care,
  48. 0 means negative sample, and positive number is the index (1-based)
  49. of assigned gt.
  50. The assignment is done in the following steps, the order matters.
  51. 1. assign every prediction to -1
  52. 2. compute the weighted costs
  53. 3. do Hungarian matching on CPU based on the costs
  54. 4. assign all to 0 (background) first, then for each matched pair
  55. between predictions and gts, treat this prediction as foreground
  56. and assign the corresponding gt index (plus 1) to it.
  57. Args:
  58. pred_instances (:obj:`InstanceData`): Instances of model
  59. predictions. It includes ``priors``, and the priors can
  60. be anchors or points, or the bboxes predicted by the
  61. previous stage, has shape (n, 4). The bboxes predicted by
  62. the current model or stage will be named ``bboxes``,
  63. ``labels``, and ``scores``, the same as the ``InstanceData``
  64. in other places. It may includes ``masks``, with shape
  65. (n, h, w) or (n, l).
  66. gt_instances (:obj:`InstanceData`): Ground truth of instance
  67. annotations. It usually includes ``bboxes``, with shape (k, 4),
  68. ``labels``, with shape (k, ) and ``masks``, with shape
  69. (k, h, w) or (k, l).
  70. img_meta (dict): Image information.
  71. Returns:
  72. :obj:`AssignResult`: The assigned result.
  73. """
  74. assert isinstance(gt_instances.labels, Tensor)
  75. num_gts, num_preds = len(gt_instances), len(pred_instances)
  76. gt_labels = gt_instances.labels
  77. device = gt_labels.device
  78. # 1. assign -1 by default
  79. assigned_gt_inds = torch.full((num_preds, ),
  80. -1,
  81. dtype=torch.long,
  82. device=device)
  83. assigned_labels = torch.full((num_preds, ),
  84. -1,
  85. dtype=torch.long,
  86. device=device)
  87. if num_gts == 0 or num_preds == 0:
  88. # No ground truth or boxes, return empty assignment
  89. if num_gts == 0:
  90. # No ground truth, assign all to background
  91. assigned_gt_inds[:] = 0
  92. return AssignResult(
  93. num_gts=num_gts,
  94. gt_inds=assigned_gt_inds,
  95. max_overlaps=None,
  96. labels=assigned_labels)
  97. # 2. compute weighted cost
  98. cost_list = []
  99. for match_cost in self.match_costs:
  100. cost = match_cost(
  101. pred_instances=pred_instances,
  102. gt_instances=gt_instances,
  103. img_meta=img_meta)
  104. cost_list.append(cost)
  105. cost = torch.stack(cost_list).sum(dim=0)
  106. # 3. do Hungarian matching on CPU using linear_sum_assignment
  107. cost = cost.detach().cpu()
  108. if linear_sum_assignment is None:
  109. raise ImportError('Please run "pip install scipy" '
  110. 'to install scipy first.')
  111. matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
  112. matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
  113. matched_col_inds = torch.from_numpy(matched_col_inds).to(device)
  114. # 4. assign backgrounds and foregrounds
  115. # assign all indices to backgrounds first
  116. assigned_gt_inds[:] = 0
  117. # assign foregrounds based on matching results
  118. assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
  119. assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
  120. return AssignResult(
  121. num_gts=num_gts,
  122. gt_inds=assigned_gt_inds,
  123. max_overlaps=None,
  124. labels=assigned_labels)