# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import reduce_mean


def compute_mask_iou(inputs, targets):
    """Compute per-pair mask IoU between binarized predictions and targets.

    Both ``inputs`` (raw logits) and ``targets`` are expected to be
    flattened to shape (num_masks, H * W); the result has shape (num_masks,).
    """
    inputs = inputs.sigmoid()
    # binarize predictions and targets before computing IoU
    binarized_inputs = (inputs >= 0.4).float()
    targets = (targets > 0.5).float()
    intersection = (binarized_inputs * targets).sum(-1)
    union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection
    score = intersection / (union + 1e-6)
    return score
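

# A minimal sketch (not part of the original file) of what compute_mask_iou
# returns for toy inputs; values are illustrative, not from a real model:
#
#   >>> logits = torch.full((2, 4), 10.0)  # sigmoid ~= 1 everywhere
#   >>> gts = torch.tensor([[1., 1., 1., 1.],
#   ...                     [1., 1., 0., 0.]])
#   >>> compute_mask_iou(logits, gts)      # -> approx tensor([1.0, 0.5])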


def dice_score(inputs, targets):
    """Compute the pairwise (soft) Dice score matrix.

    Unlike ``compute_mask_iou``, this is all-pairs: for ``inputs`` of shape
    (N, H * W) and ``targets`` of shape (M, H * W) it returns an (N, M)
    score matrix, which the matcher below uses as a matching quality.
    """
    inputs = inputs.sigmoid()
    numerator = 2 * torch.matmul(inputs, targets.t())
    denominator = (inputs * inputs).sum(-1)[:, None] + (
        targets * targets).sum(-1)
    score = numerator / (denominator + 1e-4)
    return score
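

# Shape sketch (illustrative, not part of the original file): with 100
# predictions and 7 ground-truth masks on a 32x32 grid,
#
#   >>> preds = torch.randn(100, 32 * 32)
#   >>> gts = torch.randint(0, 2, (7, 32 * 32)).float()
#   >>> dice_score(preds, gts).shape
#   torch.Size([100, 7])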


@MODELS.register_module()
class SparseInstCriterion(nn.Module):
    """Loss criterion for SparseInst.

    This part is partially derived from
    https://github.com/facebookresearch/detr/blob/main/models/detr.py.
    """

    def __init__(
        self,
        num_classes,
        assigner,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0),
    ):
        super().__init__()
        self.matcher = TASK_UTILS.build(assigner)
        self.num_classes = num_classes
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_obj = MODELS.build(loss_obj)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)
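
    # A hypothetical registry config (not from the original file) showing how
    # this criterion is typically built; 'SparseInstMatcher' refers to the
    # assigner defined at the bottom of this module:
    #
    #   criterion = MODELS.build(
    #       dict(
    #           type='SparseInstCriterion',
    #           num_classes=80,
    #           assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)))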

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
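
    # Sketch of the expected ``indices`` layout (illustrative, not from the
    # original file): one (pred_idx, gt_idx) pair of index tensors per image,
    # as produced by SparseInstMatcher below. For a batch of two images,
    #
    #   indices = [(torch.tensor([4, 17]), torch.tensor([0, 1])),
    #              (torch.tensor([3]), torch.tensor([0]))]
    #
    # _get_src_permutation_idx flattens this into batch/query index tensors
    # ([0, 0, 1], [4, 17, 3]) that can index a (B, N, ...) prediction tensor.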

    def loss_classification(self, outputs, batch_gt_instances, indices,
                            num_instances):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)])
        # unmatched predictions are assigned the background class
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o
        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        # compute the focal loss, normalized by the number of instances
        class_loss = self.loss_cls(
            src_logits,
            target_classes,
        ) / num_instances
        return class_loss
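
    # Target-construction sketch (illustrative, not part of the original
    # file): with num_classes = 3 and N = 4 queries per image, matching
    # queries 1 and 3 to gt labels [2, 0] yields per-query targets
    # [3, 2, 3, 0], where 3 is the background index given to unmatched
    # queries.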

    def loss_masks_with_iou_objectness(self, outputs, batch_gt_instances,
                                       indices, num_instances):
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        # pred_masks: BxNxHxW, pred_scores: BxNx1
        assert 'pred_masks' in outputs
        assert 'pred_scores' in outputs
        src_iou_scores = outputs['pred_scores']
        src_masks = outputs['pred_masks']
        with torch.no_grad():
            target_masks = torch.cat([
                gt.masks.to_tensor(
                    dtype=src_masks.dtype, device=src_masks.device)
                for gt in batch_gt_instances
            ])
        num_masks = [len(gt.masks) for gt in batch_gt_instances]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            # no ground truth in the batch: return zero losses that still
            # keep the graph connected to the predictions
            loss_dice = src_masks.sum() * 0.0
            loss_mask = src_masks.sum() * 0.0
            loss_objectness = src_iou_scores.sum() * 0.0
            return loss_objectness, loss_dice, loss_mask
        src_masks = src_masks[src_idx]
        target_masks = F.interpolate(
            target_masks[:, None],
            size=src_masks.shape[-2:],
            mode='bilinear',
            align_corners=False).squeeze(1)
        src_masks = src_masks.flatten(1)
        # The matcher returns per-image target indices, but ``target_masks``
        # was concatenated across the batch, so add the cumulative offset of
        # each image's masks before indexing.
        mix_tgt_idx = torch.zeros_like(tgt_idx[1])
        cum_sum = 0
        for num_mask in num_masks:
            mix_tgt_idx[cum_sum:cum_sum + num_mask] = cum_sum
            cum_sum += num_mask
        mix_tgt_idx += tgt_idx[1]
        target_masks = target_masks[mix_tgt_idx].flatten(1)
        # the mask IoU between matched predictions and targets serves as the
        # regression target for the IoU-aware objectness branch
        with torch.no_grad():
            ious = compute_mask_iou(src_masks, target_masks)
        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)
        loss_objectness = self.loss_obj(src_iou_scores, tgt_iou_scores)
        loss_dice = self.loss_dice(src_masks, target_masks) / num_instances
        loss_mask = self.loss_mask(src_masks, target_masks)
        return loss_objectness, loss_dice, loss_mask
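
    # Worked offset sketch (illustrative, not part of the original file):
    #
    #   >>> num_masks = [2, 3]                       # masks per image
    #   >>> tgt = torch.tensor([0, 1, 0, 1, 2])      # per-image gt indices
    #   >>> offsets = torch.tensor([0, 0, 2, 2, 2])  # cumulative mask counts
    #   >>> (tgt + offsets).tolist()                 # rows in concat tensor
    #   [0, 1, 2, 3, 4]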

    def forward(self, outputs, batch_gt_instances, batch_img_metas,
                batch_gt_instances_ignore):
        # Retrieve the matching between the outputs of
        # the last layer and the targets
        indices = self.matcher(outputs, batch_gt_instances)
        # Compute the average number of target instances
        # across all nodes, for normalization purposes
        num_instances = sum(gt.labels.shape[0] for gt in batch_gt_instances)
        num_instances = torch.as_tensor([num_instances],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
        num_instances = reduce_mean(num_instances).clamp_(min=1).item()
        # Compute all the requested losses
        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            indices, num_instances)
        loss_obj, loss_dice, loss_mask = self.loss_masks_with_iou_objectness(
            outputs, batch_gt_instances, indices, num_instances)
        return dict(
            loss_cls=loss_cls,
            loss_obj=loss_obj,
            loss_dice=loss_dice,
            loss_mask=loss_mask)
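

# A hypothetical training-step sketch (not from the original file); the
# ``outputs`` keys and shapes follow what the criterion asserts on above:
#
#   >>> outputs = dict(
#   ...     pred_logits=torch.randn(2, 100, 80),     # B x N x num_classes
#   ...     pred_masks=torch.randn(2, 100, 40, 40),  # B x N x H x W
#   ...     pred_scores=torch.randn(2, 100, 1))      # B x N x 1 (IoU branch)
#   >>> losses = criterion(outputs, batch_gt_instances, batch_img_metas, None)
#   >>> sum(losses.values()).backward()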


@TASK_UTILS.register_module()
class SparseInstMatcher(nn.Module):
    """Dice-based Hungarian matcher for SparseInst.

    The matching score is ``dice**alpha * prob**beta``, combining mask
    quality with classification probability.
    """

    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mask_score = dice_score

    def forward(self, outputs, batch_gt_instances):
        with torch.no_grad():
            B, N, H, W = outputs['pred_masks'].shape
            pred_masks = outputs['pred_masks']
            pred_logits = outputs['pred_logits'].sigmoid()
            device = pred_masks.device
            tgt_ids = torch.cat([gt.labels for gt in batch_gt_instances])
            if tgt_ids.shape[0] == 0:
                # no ground truth: return empty integer index pairs so the
                # results remain usable for tensor indexing downstream
                empty = torch.as_tensor([], dtype=torch.int64, device=device)
                return [(empty, empty)] * B
            tgt_masks = torch.cat([
                gt.masks.to_tensor(dtype=pred_masks.dtype, device=device)
                for gt in batch_gt_instances
            ])
            tgt_masks = F.interpolate(
                tgt_masks[:, None],
                size=pred_masks.shape[-2:],
                mode='bilinear',
                align_corners=False).squeeze(1)
            pred_masks = pred_masks.view(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            # force fp32 for a numerically stable cost matrix under AMP
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                mask_score = self.mask_score(pred_masks, tgt_masks)
                # matching probability: (B*N) x (number of gts)
                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]
                C = (mask_score**self.alpha) * (matching_prob**self.beta)
            C = C.view(B, N, -1).cpu()
            # hungarian matching per image, maximizing the combined score
            sizes = [len(gt.masks) for gt in batch_gt_instances]
            indices = [
                linear_sum_assignment(c[i], maximize=True)
                for i, c in enumerate(C.split(sizes, -1))
            ]
            indices = [(torch.as_tensor(i, dtype=torch.int64),
                        torch.as_tensor(j, dtype=torch.int64))
                       for i, j in indices]
            return indices
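

# A minimal sketch (assumptions: ``outputs`` shaped as above, and gt instances
# carrying ``labels`` and bitmap-style ``masks``) of running the matcher:
#
#   >>> matcher = TASK_UTILS.build(
#   ...     dict(type='SparseInstMatcher', alpha=0.8, beta=0.2))
#   >>> indices = matcher(outputs, batch_gt_instances)
#   >>> len(indices)  # one (pred_idx, gt_idx) pair per image
#   2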