# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/loss.py # noqa
# This work is licensed under the CC-BY-NC 4.0 License.
# Users should be careful about adopting these features in any commercial matters. # noqa
# For more details, please refer to https://github.com/ShoufaChen/DiffusionDet/blob/main/LICENSE # noqa
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.utils import ConfigType


@TASK_UTILS.register_module()
class DiffusionDetCriterion(nn.Module):

    def __init__(
        self,
        num_classes,
        assigner: Union[ConfigDict, nn.Module],
        deep_supervision=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_bbox=dict(type='L1Loss', reduction='sum', loss_weight=5.0),
        loss_giou=dict(type='GIoULoss', reduction='sum', loss_weight=2.0),
    ):
        super().__init__()
        self.num_classes = num_classes
        if isinstance(assigner, nn.Module):
            self.assigner = assigner
        else:
            self.assigner = TASK_UTILS.build(assigner)
        self.deep_supervision = deep_supervision

        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_giou = MODELS.build(loss_giou)

    def forward(self, outputs, batch_gt_instances, batch_img_metas):
        batch_indices = self.assigner(outputs, batch_gt_instances,
                                      batch_img_metas)
        # Compute all the requested losses
        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            batch_indices)
        loss_bbox, loss_giou = self.loss_boxes(outputs, batch_gt_instances,
                                               batch_indices)

        losses = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_giou=loss_giou)

        if self.deep_supervision:
            assert 'aux_outputs' in outputs
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                batch_indices = self.assigner(aux_outputs, batch_gt_instances,
                                              batch_img_metas)
                loss_cls = self.loss_classification(aux_outputs,
                                                    batch_gt_instances,
                                                    batch_indices)
                loss_bbox, loss_giou = self.loss_boxes(aux_outputs,
                                                       batch_gt_instances,
                                                       batch_indices)
                tmp_losses = dict(
                    loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_giou=loss_giou)
                for name, value in tmp_losses.items():
                    losses[f's.{i}.{name}'] = value
        return losses

    def loss_classification(self, outputs, batch_gt_instances, indices):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        target_classes_list = [
            gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)
        ]
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        for idx in range(len(batch_gt_instances)):
            target_classes[idx, indices[idx][0]] = target_classes_list[idx]

        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        # compute focal loss, normalized by the number of gt instances
        num_instances = max(torch.cat(target_classes_list).shape[0], 1)
        loss_cls = self.loss_cls(src_logits, target_classes) / num_instances
        return loss_cls

    def loss_boxes(self, outputs, batch_gt_instances, indices):
        assert 'pred_boxes' in outputs
        pred_boxes = outputs['pred_boxes']

        target_bboxes_norm_list = [
            gt.norm_bboxes_cxcywh[J]
            for gt, (_, J) in zip(batch_gt_instances, indices)
        ]
        target_bboxes_list = [
            gt.bboxes[J] for gt, (_, J) in zip(batch_gt_instances, indices)
        ]

        pred_bboxes_list = []
        pred_bboxes_norm_list = []
        for idx in range(len(batch_gt_instances)):
            pred_bboxes_list.append(pred_boxes[idx, indices[idx][0]])
            image_size = batch_gt_instances[idx].image_size
            pred_bboxes_norm_list.append(pred_boxes[idx, indices[idx][0]] /
                                         image_size)

        pred_boxes_cat = torch.cat(pred_bboxes_list)
        pred_boxes_norm_cat = torch.cat(pred_bboxes_norm_list)
        target_bboxes_cat = torch.cat(target_bboxes_list)
        target_bboxes_norm_cat = torch.cat(target_bboxes_norm_list)

        if len(pred_boxes_cat) > 0:
            num_instances = pred_boxes_cat.shape[0]
            loss_bbox = self.loss_bbox(
                pred_boxes_norm_cat,
                bbox_cxcywh_to_xyxy(target_bboxes_norm_cat)) / num_instances
            loss_giou = self.loss_giou(pred_boxes_cat,
                                       target_bboxes_cat) / num_instances
        else:
            loss_bbox = pred_boxes.sum() * 0
            loss_giou = pred_boxes.sum() * 0
        return loss_bbox, loss_giou
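

# Usage (an illustrative sketch, not part of the upstream file): the
# criterion and its matcher are normally built from a config through the
# registry, e.g. by a detection head. The cost types below exist in mmdet;
# the specific weights and `num_classes` are assumptions for illustration.
#
# criterion = TASK_UTILS.build(
#     dict(
#         type='DiffusionDetCriterion',
#         num_classes=80,
#         assigner=dict(
#             type='DiffusionDetMatcher',
#             match_costs=[
#                 dict(type='FocalLossCost', alpha=0.25, gamma=2.0,
#                      weight=2.0),
#                 dict(type='BBoxL1Cost', weight=5.0, box_format='xyxy'),
#                 dict(type='IoUCost', iou_mode='giou', weight=2.0)
#             ],
#             center_radius=2.5,
#             candidate_topk=5)))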


@TASK_UTILS.register_module()
class DiffusionDetMatcher(nn.Module):
    """This class computes an assignment between the targets and the
    predictions of the network.

    For efficiency reasons, the targets don't include the no_object.
    Because of this, in general, there are more predictions than targets.
    In this case, we do a 1-to-k (dynamic) matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                    ConfigDict],
                 center_radius: float = 2.5,
                 candidate_topk: int = 5,
                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
                 **kwargs):
        super().__init__()

        self.center_radius = center_radius
        self.candidate_topk = candidate_topk

        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            assert len(match_costs) > 0, \
                'match_costs must not be an empty list.'
        self.use_focal_loss = False
        self.use_fed_loss = False
        for _match_cost in match_costs:
            if _match_cost.get('type') == 'FocalLossCost':
                self.use_focal_loss = True
            if _match_cost.get('type') == 'FedLoss':
                self.use_fed_loss = True
                # federated loss is not supported yet
                raise NotImplementedError

        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def forward(self, outputs, batch_gt_instances, batch_img_metas):
        assert 'pred_logits' in outputs and 'pred_boxes' in outputs
        pred_logits = outputs['pred_logits']
        pred_bboxes = outputs['pred_boxes']
        batch_size = len(batch_gt_instances)

        assert batch_size == pred_logits.shape[0] == pred_bboxes.shape[0]
        batch_indices = []
        for i in range(batch_size):
            pred_instances = InstanceData()
            pred_instances.bboxes = pred_bboxes[i, ...]
            pred_instances.scores = pred_logits[i, ...]
            gt_instances = batch_gt_instances[i]
            img_meta = batch_img_metas[i]
            indices = self.single_assigner(pred_instances, gt_instances,
                                           img_meta)
            batch_indices.append(indices)
        return batch_indices
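
    # Shape sketch for the assignment above (the 500-proposal count is an
    # assumption for illustration): given pred_logits of shape
    # (B, 500, num_classes) and pred_boxes of shape (B, 500, 4) in xyxy
    # format, each entry of `batch_indices` is a tuple
    # (valid_mask, matched_gt_inds), where valid_mask is a (500,) bool mask
    # over proposals and matched_gt_inds holds one gt index per foreground
    # proposal, as produced by `single_assigner` below.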

    def single_assigner(self, pred_instances, gt_instances, img_meta):
        with torch.no_grad():
            gt_bboxes = gt_instances.bboxes
            pred_bboxes = pred_instances.bboxes
            num_gt = gt_bboxes.size(0)

            if num_gt == 0:  # empty object in key frame
                valid_mask = pred_bboxes.new_zeros((pred_bboxes.shape[0], ),
                                                   dtype=torch.bool)
                matched_gt_inds = pred_bboxes.new_zeros(
                    (gt_bboxes.shape[0], ), dtype=torch.long)
                return valid_mask, matched_gt_inds

            valid_mask, is_in_boxes_and_center = \
                self.get_in_gt_and_in_center_info(
                    bbox_xyxy_to_cxcywh(pred_bboxes),
                    bbox_xyxy_to_cxcywh(gt_bboxes))

            cost_list = []
            for match_cost in self.match_costs:
                cost = match_cost(
                    pred_instances=pred_instances,
                    gt_instances=gt_instances,
                    img_meta=img_meta)
                cost_list.append(cost)

            pairwise_ious = self.iou_calculator(pred_bboxes, gt_bboxes)

            # penalize predictions whose center is not inside both the gt
            # box and its center region, and add an even larger penalty to
            # predictions whose center is inside neither
            cost_list.append((~is_in_boxes_and_center) * 100.0)
            cost_matrix = torch.stack(cost_list).sum(0)
            cost_matrix[~valid_mask] = cost_matrix[~valid_mask] + 10000.0

            fg_mask_inboxes, matched_gt_inds = \
                self.dynamic_k_matching(
                    cost_matrix, pairwise_ious, num_gt)
        return fg_mask_inboxes, matched_gt_inds

    def get_in_gt_and_in_center_info(
            self, pred_bboxes: Tensor,
            gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]:
        """Get the information of which prior is in gt bboxes and gt center
        priors."""
        xy_target_gts = bbox_cxcywh_to_xyxy(gt_bboxes)  # (x1, y1, x2, y2)

        pred_bboxes_center_x = pred_bboxes[:, 0].unsqueeze(1)
        pred_bboxes_center_y = pred_bboxes[:, 1].unsqueeze(1)

        # whether the center of each anchor is inside a gt box
        b_l = pred_bboxes_center_x > xy_target_gts[:, 0].unsqueeze(0)
        b_r = pred_bboxes_center_x < xy_target_gts[:, 2].unsqueeze(0)
        b_t = pred_bboxes_center_y > xy_target_gts[:, 1].unsqueeze(0)
        b_b = pred_bboxes_center_y < xy_target_gts[:, 3].unsqueeze(0)
        # True where all four conditions hold; shape [num_query, num_gt]
        is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() +
                        b_b.long()) == 4)
        is_in_boxes_all = is_in_boxes.sum(1) > 0  # [num_query]
        # in center region; use the radius configured in __init__
        center_radius = self.center_radius
        # Modified to self-adapted sampling --- the center size depends
        # on the size of the gt boxes
        # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212 # noqa
        b_l = pred_bboxes_center_x > (
            gt_bboxes[:, 0] -
            (center_radius *
             (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
        b_r = pred_bboxes_center_x < (
            gt_bboxes[:, 0] +
            (center_radius *
             (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
        b_t = pred_bboxes_center_y > (
            gt_bboxes[:, 1] -
            (center_radius *
             (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
        b_b = pred_bboxes_center_y < (
            gt_bboxes[:, 1] +
            (center_radius *
             (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
        is_in_centers = ((b_l.long() + b_r.long() + b_t.long() +
                          b_b.long()) == 4)
        is_in_centers_all = is_in_centers.sum(1) > 0

        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = (is_in_boxes & is_in_centers)

        return is_in_boxes_anchor, is_in_boxes_and_center
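
    # Worked example of the dynamic-k rule implemented below: for each gt,
    # k is the truncated sum of its top-`candidate_topk` IoUs, clamped to at
    # least 1. E.g. (numbers assumed for illustration) if the top-5 IoUs for
    # a gt are [0.8, 0.6, 0.5, 0.3, 0.2], their sum is 2.4, so that gt is
    # matched to its k = int(2.4) = 2 cheapest queries.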

    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
                           num_gt: int) -> Tuple[Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets."""
        matching_matrix = torch.zeros_like(cost)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # calculate dynamic k for each gt
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        # if a query is matched to multiple gts, keep only the cheapest one
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.sum() > 0:
            _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] *= 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        # ensure every gt is matched to at least one query
        while (matching_matrix.sum(0) == 0).any():
            matched_query_id = matching_matrix.sum(1) > 0
            cost[matched_query_id] += 100000.0
            unmatch_id = torch.nonzero(
                matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
            for gt_idx in unmatch_id:
                pos_idx = torch.argmin(cost[:, gt_idx])
                matching_matrix[:, gt_idx][pos_idx] = 1.0
            if (matching_matrix.sum(1) > 1).sum() > 0:
                # re-compute the mask of queries that now match multiple
                # gts; the mask computed before the loop would be stale here
                prior_match_gt_mask = matching_matrix.sum(1) > 1
                _, cost_argmin = torch.min(cost[prior_match_gt_mask], dim=1)
                matching_matrix[prior_match_gt_mask] *= 0
                matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        assert not (matching_matrix.sum(0) == 0).any()
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        return fg_mask_inboxes, matched_gt_inds
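

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the upstream
    # file): exercise dynamic_k_matching on random inputs. It assumes an
    # mmdet installation whose FocalLossCost and BboxOverlaps2D modules can
    # be registered, which the matcher needs at construction time.
    from mmdet.utils import register_all_modules
    register_all_modules(init_default_scope=False)

    matcher = DiffusionDetMatcher(
        match_costs=dict(type='FocalLossCost', weight=2.0))
    num_query, num_gt = 100, 4
    cost = torch.rand(num_query, num_gt)
    pairwise_ious = torch.rand(num_query, num_gt)
    fg_mask, matched_gt_inds = matcher.dynamic_k_matching(
        cost, pairwise_ious, num_gt)
    # every foreground query carries exactly one matched gt index
    assert fg_mask.sum() == matched_gt_inds.numel()
    print(f'{int(fg_mask.sum())} of {num_query} queries matched '
          f'to {num_gt} gts')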