# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes
from mmdet.utils import ConfigType
from .assign_result import AssignResult
from .base_assigner import BaseAssigner

INF = 100000000
EPS = 1.0e-7


def center_of_mass(masks: Tensor, eps: float = 1e-7) -> Tensor:
    """Compute the masks center of mass.

    Args:
        masks: Mask tensor, has shape (num_masks, H, W).
        eps: a small number to avoid normalizer to be zero.
            Defaults to 1e-7.

    Returns:
        Tensor: The masks center of mass, in (x, y) order.
        Has shape (num_masks, 2).
    """
    _, h, w = masks.shape
    # Column vector of row indices / row vector of column indices; both
    # broadcast against the (num_masks, H, W) mask tensor.
    grid_h = torch.arange(h, device=masks.device)[:, None]
    grid_w = torch.arange(w, device=masks.device)
    # Clamp so that an all-zero mask does not divide by zero.
    normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps)
    center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer
    center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer
    center = torch.cat([center_x[:, None], center_y[:, None]], dim=1)
    return center


@TASK_UTILS.register_module()
class DynamicSoftLabelAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth with dynamic
    soft label assignment.

    Args:
        soft_center_radius (float): Radius of the soft center prior.
            Defaults to 3.0.
        topk (int): Select top-k predictions to calculate dynamic k
            best matches for each gt. Defaults to 13.
        iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
        iou_calculator (ConfigType): Config of overlaps Calculator.
            Defaults to dict(type='BboxOverlaps2D').
    """

    def __init__(
        self,
        soft_center_radius: float = 3.0,
        topk: int = 13,
        iou_weight: float = 3.0,
        iou_calculator: ConfigType = dict(type='BboxOverlaps2D')
    ) -> None:
        self.soft_center_radius = soft_center_radius
        self.topk = topk
        self.iou_weight = iou_weight
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to priors.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        num_gt = gt_bboxes.size(0)

        decoded_bboxes = pred_instances.bboxes
        pred_scores = pred_instances.scores
        priors = pred_instances.priors
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 (background) by default
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
                                                   0,
                                                   dtype=torch.long)
        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment.
            # ``assigned_gt_inds`` is already all zeros, i.e. everything is
            # background.
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
            assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
                                                      -1,
                                                      dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        # The first two prior columns are used as the prior center.
        prior_center = priors[:, :2]
        if isinstance(gt_bboxes, BaseBoxes):
            is_in_gts = gt_bboxes.find_inside_points(prior_center)
        else:
            # Tensor boxes will be treated as horizontal boxes by defaults
            lt_ = prior_center[:, None] - gt_bboxes[:, :2]
            rb_ = gt_bboxes[:, 2:] - prior_center[:, None]

            deltas = torch.cat([lt_, rb_], dim=-1)
            # A center is inside a box when all four signed distances to
            # the box sides are positive.
            is_in_gts = deltas.min(dim=-1).values > 0

        # Keep only priors whose center falls inside at least one gt.
        valid_mask = is_in_gts.sum(dim=1) > 0

        valid_decoded_bbox = decoded_bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        num_valid = valid_decoded_bbox.size(0)

        if num_valid == 0:
            # No prior center lies inside any gt box, return empty
            # assignment (everything stays background).
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
            assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
                                                      -1,
                                                      dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        if hasattr(gt_instances, 'masks'):
            gt_center = center_of_mass(gt_instances.masks, eps=EPS)
        elif isinstance(gt_bboxes, BaseBoxes):
            gt_center = gt_bboxes.centers
        else:
            # Tensor boxes will be treated as horizontal boxes by defaults
            gt_center = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2.0

        valid_prior = priors[valid_mask]
        # The third prior column is used as the stride, so the distance
        # below is expressed in units of stride.
        strides = valid_prior[:, 2]
        distance = (valid_prior[:, None, :2] - gt_center[None, :, :]
                    ).pow(2).sum(-1).sqrt() / strides[:, None]
        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)

        pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
        iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight

        gt_onehot_label = (
            F.one_hot(gt_labels.to(torch.int64),
                      pred_scores.shape[-1]).float().unsqueeze(0).repeat(
                          num_valid, 1, 1))
        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)

        # Soft label: the IoU with the gt at the gt's class, zero elsewhere.
        soft_label = gt_onehot_label * pairwise_ious[..., None]
        scale_factor = soft_label - valid_pred_scores.sigmoid()
        soft_cls_cost = F.binary_cross_entropy_with_logits(
            valid_pred_scores, soft_label,
            reduction='none') * scale_factor.abs().pow(2.0)
        soft_cls_cost = soft_cls_cost.sum(dim=-1)

        cost_matrix = soft_cls_cost + iou_cost + soft_center_prior

        # NOTE: ``dynamic_k_matching`` narrows ``valid_mask`` in place to
        # the priors that ended up matched.
        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, pairwise_ious, num_gt, valid_mask)

        # convert to AssignResult format
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        # Boolean-mask assignment needs matching dtypes, so align the IoUs
        # with the float32 buffer in case they come in another precision.
        max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps.dtype)
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
                           num_gt: int,
                           valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets. Same as SimOTA.

        Args:
            cost (Tensor): Cost matrix.
            pairwise_ious (Tensor): Pairwise iou matrix.
            num_gt (int): Number of gt.
            valid_mask (Tensor): Mask for valid bboxes. It is updated in
                place so that only the matched priors remain set.

        Returns:
            tuple: matched ious and gt indexes.
        """
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # calculate dynamic k for each gt; convert to Python ints once
        # instead of handing a 0-dim tensor to ``topk`` on every iteration
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1).tolist()
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            matching_matrix[pos_idx, gt_idx] = 1

        del topk_ious, dynamic_ks

        # A prior matched to several gts keeps only its lowest-cost gt.
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.any():
            _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] *= 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        # Index with a copy of the mask so the positions being written are
        # fixed while ``valid_mask`` itself is overwritten.
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(1)[fg_mask_inboxes]
        return matched_pred_ious, matched_gt_inds