123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Union
- import torch
- from mmengine import ConfigDict
- from mmengine.structures import InstanceData
- from scipy.optimize import linear_sum_assignment
- from torch import Tensor
- from mmdet.registry import TASK_UTILS
- from .assign_result import AssignResult
- from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth.

    Every configured match cost yields a cost matrix between the predictions
    and the targets. This class sums those matrices (no additional weighting
    is applied here) and solves the resulting linear sum assignment problem.
    For DETR the costs are a weighted sum of classification cost, regression
    L1 cost and regression IoU cost.

    The targets don't include the no_object class, so generally there are
    more predictions than targets. After the one-to-one matching, the
    un-matched predictions are treated as background. Thus each query
    prediction is assigned ``0`` or a positive integer indicating the ground
    truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        match_costs (:obj:`ConfigDict` or dict or
            List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
            A single config is wrapped into a one-element list; a list must
            contain at least one config.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:
        if isinstance(match_costs, dict):
            # Normalise a single config so the rest of the class can always
            # iterate over a list of costs.
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            assert len(match_costs) > 0, \
                'match_costs must not be a empty list.'

        # Build every cost object through the registry once, up front.
        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               img_meta: Optional[dict] = None,
               **kwargs) -> AssignResult:
        """Computes one-to-one matching based on the summed costs.

        This method assigns each query prediction to a ground truth or to
        background. In ``assigned_gt_inds``, -1 means don't care, 0 means
        negative sample, and a positive number is the index (1-based) of the
        assigned gt.

        The assignment is done in the following steps, the order matters.

        1. assign every prediction to -1
        2. compute the summed cost matrix from all match costs
        3. do Hungarian matching on CPU (in float64) based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places. It may include ``masks``, with shape
                (n, h, w) or (n, l).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), ``labels``, with shape (k, ) and ``masks``, with
                shape (k, h, w) or (k, l).
            img_meta (dict, optional): Image information, forwarded as-is to
                every match cost. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assigned result. ``max_overlaps`` is
            always None. When there is no gt, all predictions are marked as
            background (0); when there is no prediction, the index tensors
            are empty.
        """
        assert isinstance(gt_instances.labels, Tensor)
        num_gts, num_preds = len(gt_instances), len(pred_instances)
        gt_labels = gt_instances.labels
        # All outputs live on the same device as the gt labels.
        device = gt_labels.device

        # 1. assign -1 by default
        assigned_gt_inds = torch.full((num_preds, ),
                                      -1,
                                      dtype=torch.long,
                                      device=device)
        assigned_labels = torch.full((num_preds, ),
                                     -1,
                                     dtype=torch.long,
                                     device=device)

        if num_gts == 0 or num_preds == 0:
            # Nothing to match: skip the solver entirely.
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        # 2. compute the summed cost; each match cost returns one matrix
        cost_list = [
            match_cost(
                pred_instances=pred_instances,
                gt_instances=gt_instances,
                img_meta=img_meta) for match_cost in self.match_costs
        ]
        cost = torch.stack(cost_list).sum(dim=0)

        # 3. do Hungarian matching on CPU using linear_sum_assignment.
        # Upcasting to float64 is lossless for every floating dtype and lets
        # us hand SciPy a real NumPy array even when the cost was computed in
        # a dtype that cannot be converted to NumPy directly (e.g. bfloat16
        # under mixed precision).
        cost = cost.detach().cpu().to(torch.float64)
        matched_row_inds, matched_col_inds = linear_sum_assignment(
            cost.numpy())
        matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results; gt indices are
        # stored 1-based because 0 is reserved for background
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)
|