loss.py

# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import reduce_mean


def compute_mask_iou(inputs, targets):
    inputs = inputs.sigmoid()
    # thresholding
    binarized_inputs = (inputs >= 0.4).float()
    targets = (targets > 0.5).float()
    intersection = (binarized_inputs * targets).sum(-1)
    union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection
    score = intersection / (union + 1e-6)
    return score
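

# Illustrative usage sketch (not part of the original file):
# compute_mask_iou scores flattened mask logits against binary targets
# row-for-row, returning one IoU per prediction. The shapes below are
# assumptions chosen for demonstration.
def _demo_compute_mask_iou():
    pred_logits = torch.randn(3, 64 * 64)  # 3 masks, flattened to H*W
    gt_masks = torch.randint(0, 2, (3, 64 * 64)).float()
    ious = compute_mask_iou(pred_logits, gt_masks)
    assert ious.shape == (3, )  # one IoU score per prediction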


def dice_score(inputs, targets):
    # pairwise dice: returns a (num_inputs, num_targets) score matrix
    inputs = inputs.sigmoid()
    numerator = 2 * torch.matmul(inputs, targets.t())
    denominator = (inputs * inputs).sum(-1)[:, None] + (
        targets * targets).sum(-1)
    score = numerator / (denominator + 1e-4)
    return score
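

# Illustrative sketch (not part of the original file): unlike
# compute_mask_iou, dice_score is pairwise, scoring every prediction
# against every target via the matmul. Shapes here are assumptions.
def _demo_dice_score():
    pred_logits = torch.randn(5, 32 * 32)  # 5 predictions
    gt_masks = torch.randint(0, 2, (2, 32 * 32)).float()  # 2 targets
    scores = dice_score(pred_logits, gt_masks)
    assert scores.shape == (5, 2)  # every prediction vs. every target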


@MODELS.register_module()
class SparseInstCriterion(nn.Module):
    """This part is partially derived from:
    https://github.com/facebookresearch/detr/blob/main/models/detr.py.
    """

    def __init__(
        self,
        num_classes,
        assigner,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0),
    ):
        super().__init__()
        self.matcher = TASK_UTILS.build(assigner)
        self.num_classes = num_classes
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_obj = MODELS.build(loss_obj)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def loss_classification(self, outputs, batch_gt_instances, indices,
                            num_instances):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)])
        # unmatched queries are assigned the background index `num_classes`
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o
        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        # compute focal loss, normalized by the average instance count
        class_loss = self.loss_cls(
            src_logits,
            target_classes,
        ) / num_instances
        return class_loss

    def loss_masks_with_iou_objectness(self, outputs, batch_gt_instances,
                                       indices, num_instances):
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        # Bx100xHxW
        assert 'pred_masks' in outputs
        assert 'pred_scores' in outputs
        src_iou_scores = outputs['pred_scores']
        src_masks = outputs['pred_masks']
        with torch.no_grad():
            target_masks = torch.cat([
                gt.masks.to_tensor(
                    dtype=src_masks.dtype, device=src_masks.device)
                for gt in batch_gt_instances
            ])
        num_masks = [len(gt.masks) for gt in batch_gt_instances]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            # no gt instances in the batch: return zero losses that still
            # keep the computation graph connected
            loss_dice = src_masks.sum() * 0.0
            loss_mask = src_masks.sum() * 0.0
            loss_objectness = src_iou_scores.sum() * 0.0
            return loss_objectness, loss_dice, loss_mask
        src_masks = src_masks[src_idx]
        target_masks = F.interpolate(
            target_masks[:, None],
            size=src_masks.shape[-2:],
            mode='bilinear',
            align_corners=False).squeeze(1)
        src_masks = src_masks.flatten(1)
        # FIXME: tgt_idx
        # offset the per-image target indices so they index into the
        # concatenated target_masks tensor
        mix_tgt_idx = torch.zeros_like(tgt_idx[1])
        cum_sum = 0
        for num_mask in num_masks:
            mix_tgt_idx[cum_sum:cum_sum + num_mask] = cum_sum
            cum_sum += num_mask
        mix_tgt_idx += tgt_idx[1]
        target_masks = target_masks[mix_tgt_idx].flatten(1)
        with torch.no_grad():
            # the objectness branch regresses the mask IoU of each
            # matched prediction
            ious = compute_mask_iou(src_masks, target_masks)
        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)
        loss_objectness = self.loss_obj(src_iou_scores, tgt_iou_scores)
        loss_dice = self.loss_dice(src_masks, target_masks) / num_instances
        loss_mask = self.loss_mask(src_masks, target_masks)
        return loss_objectness, loss_dice, loss_mask

    def forward(self, outputs, batch_gt_instances, batch_img_metas,
                batch_gt_instances_ignore):
        # Retrieve the matching between the outputs of
        # the last layer and the targets
        indices = self.matcher(outputs, batch_gt_instances)
        # Compute the average number of target instances
        # across all nodes, for normalization purposes
        num_instances = sum(gt.labels.shape[0] for gt in batch_gt_instances)
        num_instances = torch.as_tensor([num_instances],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
        num_instances = reduce_mean(num_instances).clamp_(min=1).item()
        # Compute all the requested losses
        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            indices, num_instances)
        loss_obj, loss_dice, loss_mask = self.loss_masks_with_iou_objectness(
            outputs, batch_gt_instances, indices, num_instances)
        return dict(
            loss_cls=loss_cls,
            loss_obj=loss_obj,
            loss_dice=loss_dice,
            loss_mask=loss_mask)
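

# Illustrative sketch (not part of the original file): constructing the
# criterion directly; the registries then resolve the assigner and the
# default loss configs declared in __init__. `num_classes=80` (COCO) is
# an assumption, and this presumes mmdet's registries are populated.
def _demo_build_criterion():
    criterion = SparseInstCriterion(
        num_classes=80,
        assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2))
    # forward() takes the network outputs ('pred_logits', 'pred_masks',
    # 'pred_scores'), the per-image gt instances, image metas, and ignored
    # instances; it returns dict(loss_cls, loss_obj, loss_dice, loss_mask).
    return criterion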


@TASK_UTILS.register_module()
class SparseInstMatcher(nn.Module):
    """Hungarian matcher scoring each prediction-gt pair with
    dice_score**alpha * class_prob**beta."""

    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mask_score = dice_score

    def forward(self, outputs, batch_gt_instances):
        with torch.no_grad():
            B, N, H, W = outputs['pred_masks'].shape
            pred_masks = outputs['pred_masks']
            pred_logits = outputs['pred_logits'].sigmoid()
            device = pred_masks.device
            tgt_ids = torch.cat([gt.labels for gt in batch_gt_instances])

            if tgt_ids.shape[0] == 0:
                return [(torch.as_tensor([]).to(pred_logits),
                         torch.as_tensor([]).to(pred_logits))] * B
            tgt_masks = torch.cat([
                gt.masks.to_tensor(dtype=pred_masks.dtype, device=device)
                for gt in batch_gt_instances
            ])
            tgt_masks = F.interpolate(
                tgt_masks[:, None],
                size=pred_masks.shape[-2:],
                mode='bilinear',
                align_corners=False).squeeze(1)

            pred_masks = pred_masks.view(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                mask_score = self.mask_score(pred_masks, tgt_masks)
                # Nx(Number of gts)
                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]
                C = (mask_score**self.alpha) * (matching_prob**self.beta)
            C = C.view(B, N, -1).cpu()
            # hungarian matching
            sizes = [len(gt.masks) for gt in batch_gt_instances]
            indices = [
                linear_sum_assignment(c[i], maximize=True)
                for i, c in enumerate(C.split(sizes, -1))
            ]
            indices = [(torch.as_tensor(i, dtype=torch.int64),
                        torch.as_tensor(j, dtype=torch.int64))
                       for i, j in indices]
            return indices
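

# Illustrative sketch (not part of the original file): exercising the
# matcher end to end on random tensors. `_FakeGt` and `_FakeMasks` are
# hypothetical stand-ins for mmdet's InstanceData/BitmapMasks, mimicking
# only the attributes the matcher actually touches.
def _demo_matcher():

    class _FakeMasks:

        def __init__(self, masks):
            self._masks = masks

        def __len__(self):
            return len(self._masks)

        def to_tensor(self, dtype, device):
            return self._masks.to(dtype=dtype, device=device)

    class _FakeGt:

        def __init__(self, labels, masks):
            self.labels = labels
            self.masks = _FakeMasks(masks)

    B, N, H, W, num_classes = 1, 4, 16, 16, 3
    outputs = dict(
        pred_masks=torch.randn(B, N, H, W),
        pred_logits=torch.randn(B, N, num_classes))
    gts = [_FakeGt(torch.tensor([0, 2]), torch.rand(2, 64, 64))]
    indices = SparseInstMatcher()(outputs, gts)
    pred_idx, tgt_idx = indices[0]  # one (pred, gt) index pair per image
    assert len(pred_idx) == len(tgt_idx) == 2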