loss.py

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/loss.py # noqa
# This work is licensed under the CC-BY-NC 4.0 License.
# Users should be careful about adopting these features in any commercial matters. # noqa
# For more details, please refer to https://github.com/ShoufaChen/DiffusionDet/blob/main/LICENSE # noqa
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.utils import ConfigType

@TASK_UTILS.register_module()
class DiffusionDetCriterion(nn.Module):

    def __init__(
        self,
        num_classes,
        assigner: Union[ConfigDict, nn.Module],
        deep_supervision=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_bbox=dict(type='L1Loss', reduction='sum', loss_weight=5.0),
        loss_giou=dict(type='GIoULoss', reduction='sum', loss_weight=2.0),
    ):
        super().__init__()
        self.num_classes = num_classes
        if isinstance(assigner, nn.Module):
            self.assigner = assigner
        else:
            self.assigner = TASK_UTILS.build(assigner)
        self.deep_supervision = deep_supervision

        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_giou = MODELS.build(loss_giou)

    def forward(self, outputs, batch_gt_instances, batch_img_metas):
        batch_indices = self.assigner(outputs, batch_gt_instances,
                                      batch_img_metas)
        # Compute all the requested losses
        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            batch_indices)
        loss_bbox, loss_giou = self.loss_boxes(outputs, batch_gt_instances,
                                               batch_indices)

        losses = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_giou=loss_giou)

        if self.deep_supervision:
            assert 'aux_outputs' in outputs
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                batch_indices = self.assigner(aux_outputs, batch_gt_instances,
                                              batch_img_metas)
                loss_cls = self.loss_classification(aux_outputs,
                                                    batch_gt_instances,
                                                    batch_indices)
                loss_bbox, loss_giou = self.loss_boxes(aux_outputs,
                                                       batch_gt_instances,
                                                       batch_indices)
                tmp_losses = dict(
                    loss_cls=loss_cls,
                    loss_bbox=loss_bbox,
                    loss_giou=loss_giou)
                for name, value in tmp_losses.items():
                    losses[f's.{i}.{name}'] = value
        return losses
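
    # A minimal sketch of what ``forward`` consumes and produces
    # (illustrative; the shapes are inferred from the indexing in this
    # file, not stated by the original authors):
    #
    #   outputs = dict(
    #       pred_logits=...,    # (batch, num_proposals, num_classes)
    #       pred_boxes=...,     # (batch, num_proposals, 4), xyxy, image scale
    #       aux_outputs=[...])  # per-stage dicts with the same two keys
    #   losses = criterion(outputs, batch_gt_instances, batch_img_metas)
    #   # -> {'loss_cls': ..., 'loss_bbox': ..., 'loss_giou': ...} plus
    #   #    's.{i}.*' entries per auxiliary stage when deep_supervision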

    def loss_classification(self, outputs, batch_gt_instances, indices):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        target_classes_list = [
            gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)
        ]
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        for idx in range(len(batch_gt_instances)):
            target_classes[idx, indices[idx][0]] = target_classes_list[idx]

        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        # Compute the focal loss, normalized by the number of matched
        # instances (at least 1 to avoid division by zero).
        num_instances = max(torch.cat(target_classes_list).shape[0], 1)
        loss_cls = self.loss_cls(
            src_logits,
            target_classes,
        ) / num_instances
        return loss_cls
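
    # Shape sketch for the classification target (inferred, illustrative):
    #   src_logits:     (batch * num_proposals, num_classes)
    #   target_classes: (batch * num_proposals,) with values in
    #                   [0, num_classes], where num_classes marks the
    #                   unmatched (background) slots filled above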

    def loss_boxes(self, outputs, batch_gt_instances, indices):
        assert 'pred_boxes' in outputs
        pred_boxes = outputs['pred_boxes']

        target_bboxes_norm_list = [
            gt.norm_bboxes_cxcywh[J]
            for gt, (_, J) in zip(batch_gt_instances, indices)
        ]
        target_bboxes_list = [
            gt.bboxes[J] for gt, (_, J) in zip(batch_gt_instances, indices)
        ]

        pred_bboxes_list = []
        pred_bboxes_norm_list = []
        for idx in range(len(batch_gt_instances)):
            pred_bboxes_list.append(pred_boxes[idx, indices[idx][0]])
            image_size = batch_gt_instances[idx].image_size
            pred_bboxes_norm_list.append(pred_boxes[idx, indices[idx][0]] /
                                         image_size)

        pred_boxes_cat = torch.cat(pred_bboxes_list)
        pred_boxes_norm_cat = torch.cat(pred_bboxes_norm_list)
        target_bboxes_cat = torch.cat(target_bboxes_list)
        target_bboxes_norm_cat = torch.cat(target_bboxes_norm_list)

        if len(pred_boxes_cat) > 0:
            num_instances = pred_boxes_cat.shape[0]
            # The L1 loss is computed on normalized xyxy boxes; the stored
            # normalized targets are cxcywh, so they are converted here.
            loss_bbox = self.loss_bbox(
                pred_boxes_norm_cat,
                bbox_cxcywh_to_xyxy(target_bboxes_norm_cat)) / num_instances
            loss_giou = self.loss_giou(pred_boxes_cat,
                                       target_bboxes_cat) / num_instances
        else:
            # No matched boxes: return zero losses that keep the graph alive.
            loss_bbox = pred_boxes.sum() * 0
            loss_giou = pred_boxes.sum() * 0
        return loss_bbox, loss_giou
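
# A minimal build-and-call sketch (the match cost entries below are
# assumptions chosen for illustration, not taken from this file):
#
#   criterion = TASK_UTILS.build(
#       dict(
#           type='DiffusionDetCriterion',
#           num_classes=80,
#           assigner=dict(
#               type='DiffusionDetMatcher',
#               match_costs=[
#                   dict(type='FocalLossCost', alpha=0.25, gamma=2.0,
#                        weight=2.0),
#                   dict(type='BBoxL1Cost', weight=5.0, box_format='xyxy'),
#                   dict(type='IoUCost', iou_mode='giou', weight=2.0)
#               ],
#               center_radius=2.5,
#               candidate_topk=5)))
#   losses = criterion(outputs, batch_gt_instances, batch_img_metas)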

@TASK_UTILS.register_module()
class DiffusionDetMatcher(nn.Module):
    """This class computes an assignment between the targets and the
    predictions of the network.

    For efficiency reasons, the targets don't include the no_object.
    Because of this, in general, there are more predictions than targets. In
    this case, we do a 1-to-k (dynamic) matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
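
    # What the assigner returns per image, sketched (shapes inferred from
    # ``single_assigner`` below, not stated by the original authors):
    #
    #   valid_mask:      (num_proposals,) bool, True for matched predictions
    #   matched_gt_inds: (num_matched,) long, the gt index assigned to each
    #                    matched prediction, in proposal order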

    def __init__(self,
                 match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                    ConfigDict],
                 center_radius: float = 2.5,
                 candidate_topk: int = 5,
                 iou_calculator: ConfigType = dict(type='BboxOverlaps2D'),
                 **kwargs):
        super().__init__()

        self.center_radius = center_radius
        self.candidate_topk = candidate_topk

        if isinstance(match_costs, dict):
            match_costs = [match_costs]
        elif isinstance(match_costs, list):
            assert len(match_costs) > 0, \
                'match_costs must not be an empty list.'
        self.use_focal_loss = False
        self.use_fed_loss = False
        for _match_cost in match_costs:
            if _match_cost.get('type') == 'FocalLossCost':
                self.use_focal_loss = True
            if _match_cost.get('type') == 'FedLoss':
                self.use_fed_loss = True
                # FedLoss-based matching is not implemented yet.
                raise NotImplementedError

        self.match_costs = [
            TASK_UTILS.build(match_cost) for match_cost in match_costs
        ]
        self.iou_calculator = TASK_UTILS.build(iou_calculator)

    def forward(self, outputs, batch_gt_instances, batch_img_metas):
        assert 'pred_logits' in outputs and 'pred_boxes' in outputs
        pred_logits = outputs['pred_logits']
        pred_bboxes = outputs['pred_boxes']
        batch_size = len(batch_gt_instances)

        assert batch_size == pred_logits.shape[0] == pred_bboxes.shape[0]
        batch_indices = []
        for i in range(batch_size):
            pred_instances = InstanceData()
            pred_instances.bboxes = pred_bboxes[i, ...]
            pred_instances.scores = pred_logits[i, ...]
            gt_instances = batch_gt_instances[i]
            img_meta = batch_img_metas[i]
            indices = self.single_assigner(pred_instances, gt_instances,
                                           img_meta)
            batch_indices.append(indices)
        return batch_indices

    def single_assigner(self, pred_instances, gt_instances, img_meta):
        with torch.no_grad():
            gt_bboxes = gt_instances.bboxes
            pred_bboxes = pred_instances.bboxes
            num_gt = gt_bboxes.size(0)

            if num_gt == 0:  # empty object in key frame
                valid_mask = pred_bboxes.new_zeros((pred_bboxes.shape[0], ),
                                                   dtype=torch.bool)
                matched_gt_inds = pred_bboxes.new_zeros(
                    (gt_bboxes.shape[0], ), dtype=torch.long)
                return valid_mask, matched_gt_inds

            valid_mask, is_in_boxes_and_center = \
                self.get_in_gt_and_in_center_info(
                    bbox_xyxy_to_cxcywh(pred_bboxes),
                    bbox_xyxy_to_cxcywh(gt_bboxes))

            cost_list = []
            for match_cost in self.match_costs:
                cost = match_cost(
                    pred_instances=pred_instances,
                    gt_instances=gt_instances,
                    img_meta=img_meta)
                cost_list.append(cost)

            pairwise_ious = self.iou_calculator(pred_bboxes, gt_bboxes)

            # Penalize predictions outside both the gt box and its center
            # prior, and heavily penalize invalid candidates.
            cost_list.append((~is_in_boxes_and_center) * 100.0)
            cost_matrix = torch.stack(cost_list).sum(0)
            cost_matrix[~valid_mask] = cost_matrix[~valid_mask] + 10000.0

            fg_mask_inboxes, matched_gt_inds = \
                self.dynamic_k_matching(
                    cost_matrix, pairwise_ious, num_gt)
        return fg_mask_inboxes, matched_gt_inds

    def get_in_gt_and_in_center_info(
            self, pred_bboxes: Tensor,
            gt_bboxes: Tensor) -> Tuple[Tensor, Tensor]:
        """Get the information of which prior is in gt bboxes and gt center
        priors."""
        xy_target_gts = bbox_cxcywh_to_xyxy(gt_bboxes)  # (x1, y1, x2, y2)

        pred_bboxes_center_x = pred_bboxes[:, 0].unsqueeze(1)
        pred_bboxes_center_y = pred_bboxes[:, 1].unsqueeze(1)

        # whether the center of each anchor is inside a gt box
        b_l = pred_bboxes_center_x > xy_target_gts[:, 0].unsqueeze(0)
        b_r = pred_bboxes_center_x < xy_target_gts[:, 2].unsqueeze(0)
        b_t = pred_bboxes_center_y > xy_target_gts[:, 1].unsqueeze(0)
        b_b = pred_bboxes_center_y < xy_target_gts[:, 3].unsqueeze(0)
        # a center is inside a gt box only if all four conditions hold,
        # giving a (num_query, num_gt) boolean matrix
        is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() +
                        b_b.long()) == 4)
        is_in_boxes_all = is_in_boxes.sum(1) > 0  # [num_query]
        # in fixed center
        center_radius = 2.5
        # Modified to self-adapted sampling --- the center size depends
        # on the size of the gt boxes
        # https://github.com/dulucas/UVO_Challenge/blob/main/Track1/detection/mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py#L212 # noqa
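        # e.g. (illustrative numbers): a gt with center cx=100 and width
        # w=40 gets an x-window of cx +/- 2.5 * w = (0, 200), so larger
        # boxes receive proportionally larger center-sampling regions.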
        b_l = pred_bboxes_center_x > (
            gt_bboxes[:, 0] -
            (center_radius *
             (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
        b_r = pred_bboxes_center_x < (
            gt_bboxes[:, 0] +
            (center_radius *
             (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
        b_t = pred_bboxes_center_y > (
            gt_bboxes[:, 1] -
            (center_radius *
             (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
        b_b = pred_bboxes_center_y < (
            gt_bboxes[:, 1] +
            (center_radius *
             (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
        is_in_centers = ((b_l.long() + b_r.long() + b_t.long() +
                          b_b.long()) == 4)
        is_in_centers_all = is_in_centers.sum(1) > 0

        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = (is_in_boxes & is_in_centers)
        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
                           num_gt: int) -> Tuple[Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets."""
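        # Worked example (illustrative numbers): with candidate_topk=5 and
        # a gt whose top-5 IoUs are (0.9, 0.8, 0.6, 0.3, 0.1), dynamic k is
        # int(0.9 + 0.8 + 0.6 + 0.3 + 0.1) = 2, so its two lowest-cost
        # predictions are matched; k is clamped to at least 1.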
        matching_matrix = torch.zeros_like(cost)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # calculate dynamic k for each gt
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)

        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        # if a prediction is matched to several gts, keep only the match
        # with the lowest cost
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.sum() > 0:
            _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] *= 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        # every gt must be matched by at least one prediction: raise the
        # cost of already-matched predictions and greedily assign the
        # cheapest remaining prediction to each unmatched gt
        while (matching_matrix.sum(0) == 0).any():
            matched_query_id = matching_matrix.sum(1) > 0
            cost[matched_query_id] += 100000.0
            unmatch_id = torch.nonzero(
                matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
            for gt_idx in unmatch_id:
                pos_idx = torch.argmin(cost[:, gt_idx])
                matching_matrix[:, gt_idx][pos_idx] = 1.0
            if (matching_matrix.sum(1) > 1).sum() > 0:
                # recompute the over-matched mask here; reusing the
                # pre-loop `prior_match_gt_mask` would be stale
                prior_match_gt_mask = matching_matrix.sum(1) > 1
                _, cost_argmin = torch.min(cost[prior_match_gt_mask], dim=1)
                matching_matrix[prior_match_gt_mask] *= 0
                matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        assert not (matching_matrix.sum(0) == 0).any()
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        return fg_mask_inboxes, matched_gt_inds