# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..utils import multi_apply
from .deformable_detr_head import DeformableDETRHead


@MODELS.register_module()
class DINOHead(DeformableDETRHead):
    r"""Head of the DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2203.03605>`_.
    """

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList, dn_meta: Dict[str, int]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries_total,
                dim), where `num_queries_total` is the sum of
                `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other `num_decoder_layers` (e.g. 6)
                references are `inter_references` (intermediate). The
                `init_reference` has shape (bs, num_queries_total, 4) and
                each `inter_reference` has shape (bs, num_queries_total, 4)
                with the last dimension arranged as (cx, cy, w, h).
            enc_outputs_class (Tensor): The score of each point on the
                encoded feature map, has shape (bs, num_feat_points,
                cls_out_channels).
            enc_outputs_coord (Tensor): The proposals generated from the
                encoded feature map, has shape (bs, num_feat_points, 4) with
                the last dimension arranged as (cx, cy, w, h).
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas, dn_meta)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        enc_cls_scores: Tensor,
        enc_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        dn_meta: Dict[str, int],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries_total, cls_out_channels), where
                `num_queries_total` is the sum of `num_denoising_queries`
                and `num_matching_queries`.
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers. Each is a 4D-tensor with normalized coordinate format
                (cx, cy, w, h) and has shape (num_decoder_layers, bs,
                num_queries_total, 4).
            enc_cls_scores (Tensor): The score of each point on the encoded
                feature map, has shape (bs, num_feat_points,
                cls_out_channels).
            enc_bbox_preds (Tensor): The proposals generated from the encoded
                feature map, has shape (bs, num_feat_points, 4) with the last
                dimension arranged as (cx, cy, w, h).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # extract the denoising and matching parts of the outputs
        (all_layers_matching_cls_scores, all_layers_matching_bbox_preds,
         all_layers_denoising_cls_scores, all_layers_denoising_bbox_preds) = \
            self.split_outputs(
                all_layers_cls_scores, all_layers_bbox_preds, dn_meta)

        loss_dict = super(DeformableDETRHead, self).loss_by_feat(
            all_layers_matching_cls_scores, all_layers_matching_bbox_preds,
            batch_gt_instances, batch_img_metas, batch_gt_instances_ignore)
        # NOTE DETRHead.loss_by_feat, not DeformableDETRHead.loss_by_feat,
        # is called here, because the encoder loss calculation differs
        # between DINO and Deformable DETR.

        # loss of proposals generated from the encoded feature map
        if enc_cls_scores is not None:
            # NOTE The enc_loss calculation of DINO is different from
            # that of Deformable DETR.
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=batch_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou

        if all_layers_denoising_cls_scores is not None:
            # calculate the denoising loss from all decoder layers
            dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn(
                all_layers_denoising_cls_scores,
                all_layers_denoising_bbox_preds,
                batch_gt_instances=batch_gt_instances,
                batch_img_metas=batch_img_metas,
                dn_meta=dn_meta)
            # collate the denoising loss
            loss_dict['dn_loss_cls'] = dn_losses_cls[-1]
            loss_dict['dn_loss_bbox'] = dn_losses_bbox[-1]
            loss_dict['dn_loss_iou'] = dn_losses_iou[-1]
            for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in \
                    enumerate(zip(dn_losses_cls[:-1], dn_losses_bbox[:-1],
                                  dn_losses_iou[:-1])):
                loss_dict[f'd{num_dec_layer}.dn_loss_cls'] = loss_cls_i
                loss_dict[f'd{num_dec_layer}.dn_loss_bbox'] = loss_bbox_i
                loss_dict[f'd{num_dec_layer}.dn_loss_iou'] = loss_iou_i
        return loss_dict
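
    # For reference (added note, not in the original file): with N decoder
    # layers, the returned dict holds, per loss type t in {cls, bbox, iou},
    #   - 'loss_t' for the last matching layer and 'd{i}.loss_t' for layers
    #     i = 0 .. N-2 (from DETRHead.loss_by_feat),
    #   - 'enc_loss_t' for the encoder proposals, and
    #   - 'dn_loss_t' / 'd{i}.dn_loss_t' for the denoising part.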

    def loss_dn(self, all_layers_denoising_cls_scores: Tensor,
                all_layers_denoising_bbox_preds: Tensor,
                batch_gt_instances: InstanceList, batch_img_metas: List[dict],
                dn_meta: Dict[str, int]) -> Tuple[List[Tensor]]:
        """Calculate the denoising loss.

        Args:
            all_layers_denoising_cls_scores (Tensor): Classification scores
                of all decoder layers in the denoising part, has shape
                (num_decoder_layers, bs, num_denoising_queries,
                cls_out_channels).
            all_layers_denoising_bbox_preds (Tensor): Regression outputs of
                all decoder layers in the denoising part. Each is a 4D-tensor
                with normalized coordinate format (cx, cy, w, h) and has
                shape (num_decoder_layers, bs, num_denoising_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.

        Returns:
            Tuple[List[Tensor]]: The loss_dn_cls, loss_dn_bbox, and
            loss_dn_iou of each decoder layer.
        """
        return multi_apply(
            self._loss_dn_single,
            all_layers_denoising_cls_scores,
            all_layers_denoising_bbox_preds,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            dn_meta=dn_meta)

    def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                        batch_gt_instances: InstanceList,
                        batch_img_metas: List[dict],
                        dn_meta: Dict[str, int]) -> Tuple[Tensor]:
        """Denoising loss for outputs from a single decoder layer.

        Args:
            dn_cls_scores (Tensor): Classification scores of a single decoder
                layer in the denoising part, has shape (bs,
                num_denoising_queries, cls_out_channels).
            dn_bbox_preds (Tensor): Regression outputs of a single decoder
                layer in the denoising part, a 3D-tensor with normalized
                coordinate format (cx, cy, w, h) and shape (bs,
                num_denoising_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.

        Returns:
            Tuple[Tensor]: A tuple including `loss_cls`, `loss_bbox` and
            `loss_iou`.
        """
        cls_reg_targets = self.get_dn_targets(batch_gt_instances,
                                              batch_img_metas, dn_meta)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)
        # construct the weighted avg_factor to match the official DETR repo
        cls_avg_factor = \
            num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if len(cls_scores) > 0:
            loss_cls = self.loss_cls(
                cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
        else:
            loss_cls = torch.zeros(
                1, dtype=cls_scores.dtype, device=cls_scores.device)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct the factors used to rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, so the learning target is normalized by the image size.
        # Here we re-scale them for calculating the IoU loss.
        bbox_preds = dn_bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
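
    # Note (added for clarity, not in the original file): negative denoising
    # queries have bbox_weights == 0, so they contribute only to the
    # classification loss (as background), while both box losses are still
    # averaged by the GPU-averaged number of positives, `num_total_pos`.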

    def get_dn_targets(self, batch_gt_instances: InstanceList,
                       batch_img_metas: List[dict],
                       dn_meta: Dict[str, int]) -> tuple:
        """Get targets in the denoising part for a batch of images.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.

        Returns:
            tuple: A tuple containing the following targets.

            - labels_list (list[Tensor]): Labels for all images.
            - label_weights_list (list[Tensor]): Label weights for all
              images.
            - bbox_targets_list (list[Tensor]): BBox targets for all images.
            - bbox_weights_list (list[Tensor]): BBox weights for all images.
            - num_total_pos (int): Number of positive samples in all images.
            - num_total_neg (int): Number of negative samples in all images.
        """
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_dn_targets_single,
             batch_gt_instances,
             batch_img_metas,
             dn_meta=dn_meta)
        num_total_pos = sum(inds.numel() for inds in pos_inds_list)
        num_total_neg = sum(inds.numel() for inds in neg_inds_list)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def _get_dn_targets_single(self, gt_instances: InstanceData,
                               img_meta: dict,
                               dn_meta: Dict[str, int]) -> tuple:
        """Get targets in the denoising part for one image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for one image.
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'. It is used to split the outputs into
                denoising and matching parts and for loss calculation.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

            - labels (Tensor): Labels of each image.
            - label_weights (Tensor): Label weights of each image.
            - bbox_targets (Tensor): BBox targets of each image.
            - bbox_weights (Tensor): BBox weights of each image.
            - pos_inds (Tensor): Sampled positive indices for each image.
            - neg_inds (Tensor): Sampled negative indices for each image.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        num_groups = dn_meta['num_denoising_groups']
        num_denoising_queries = dn_meta['num_denoising_queries']
        num_queries_each_group = int(num_denoising_queries / num_groups)
        device = gt_bboxes.device

        if len(gt_labels) > 0:
            t = torch.arange(len(gt_labels), dtype=torch.long, device=device)
            t = t.unsqueeze(0).repeat(num_groups, 1)
            pos_assigned_gt_inds = t.flatten()
            pos_inds = torch.arange(
                num_groups, dtype=torch.long, device=device)
            pos_inds = pos_inds.unsqueeze(1) * num_queries_each_group + t
            pos_inds = pos_inds.flatten()
        else:
            pos_inds = pos_assigned_gt_inds = \
                gt_bboxes.new_tensor([], dtype=torch.long)
        # the second half of each group holds the negative (noised) queries
        neg_inds = pos_inds + num_queries_each_group // 2

        # label targets
        labels = gt_bboxes.new_full((num_denoising_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_denoising_queries)

        # bbox targets
        bbox_targets = torch.zeros(num_denoising_queries, 4, device=device)
        bbox_weights = torch.zeros(num_denoising_queries, 4, device=device)
        bbox_weights[pos_inds] = 1.0
        img_h, img_w = img_meta['img_shape']

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image. Thus the learning target should be normalized by the image
        # size, and the box format should be converted from the default
        # x1y1x2y2 to cxcywh.
        factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        gt_bboxes_normalized = gt_bboxes / factor
        gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized)
        bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1])

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
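
    # Worked example (illustrative, not in the original file): with 3 GT
    # boxes, 2 denoising groups and num_denoising_queries = 12, each group
    # spans 6 queries (positives first, then negatives):
    #   pos_inds = [0, 1, 2,  6, 7, 8]
    #   neg_inds = [3, 4, 5,  9, 10, 11]
    # i.e. neg_inds = pos_inds + num_queries_each_group // 2.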

    @staticmethod
    def split_outputs(all_layers_cls_scores: Tensor,
                      all_layers_bbox_preds: Tensor,
                      dn_meta: Dict[str, int]) -> Tuple[Tensor]:
        """Split outputs of the denoising part and the matching part.

        For the total outputs of `num_queries_total` length, the former
        `num_denoising_queries` outputs are from denoising queries, and
        the rest `num_matching_queries` ones are from matching queries,
        where `num_queries_total` is the sum of `num_denoising_queries` and
        `num_matching_queries`.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries_total, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers. Each is a 4D-tensor with normalized coordinate format
                (cx, cy, w, h) and has shape (num_decoder_layers, bs,
                num_queries_total, 4).
            dn_meta (Dict[str, int]): The dictionary that stores information
                about group collation, including 'num_denoising_queries' and
                'num_denoising_groups'.

        Returns:
            Tuple[Tensor]: A tuple containing the following outputs.

            - all_layers_matching_cls_scores (Tensor): Classification scores
              of all decoder layers in the matching part, has shape
              (num_decoder_layers, bs, num_matching_queries,
              cls_out_channels).
            - all_layers_matching_bbox_preds (Tensor): Regression outputs of
              all decoder layers in the matching part. Each is a 4D-tensor
              with normalized coordinate format (cx, cy, w, h) and has shape
              (num_decoder_layers, bs, num_matching_queries, 4).
            - all_layers_denoising_cls_scores (Tensor): Classification scores
              of all decoder layers in the denoising part, has shape
              (num_decoder_layers, bs, num_denoising_queries,
              cls_out_channels).
            - all_layers_denoising_bbox_preds (Tensor): Regression outputs of
              all decoder layers in the denoising part. Each is a 4D-tensor
              with normalized coordinate format (cx, cy, w, h) and has shape
              (num_decoder_layers, bs, num_denoising_queries, 4).
        """
        if dn_meta is not None:
            # look up the split point only when denoising queries exist,
            # so that `dn_meta=None` does not raise a TypeError
            num_denoising_queries = dn_meta['num_denoising_queries']
            all_layers_denoising_cls_scores = \
                all_layers_cls_scores[:, :, :num_denoising_queries, :]
            all_layers_denoising_bbox_preds = \
                all_layers_bbox_preds[:, :, :num_denoising_queries, :]
            all_layers_matching_cls_scores = \
                all_layers_cls_scores[:, :, num_denoising_queries:, :]
            all_layers_matching_bbox_preds = \
                all_layers_bbox_preds[:, :, num_denoising_queries:, :]
        else:
            all_layers_denoising_cls_scores = None
            all_layers_denoising_bbox_preds = None
            all_layers_matching_cls_scores = all_layers_cls_scores
            all_layers_matching_bbox_preds = all_layers_bbox_preds
        return (all_layers_matching_cls_scores,
                all_layers_matching_bbox_preds,
                all_layers_denoising_cls_scores,
                all_layers_denoising_bbox_preds)
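

# A minimal, runnable shape check (illustrative, not part of the original
# file). It exercises only the static `split_outputs` helper, assuming
# 6 decoder layers, batch size 2, 200 denoising + 900 matching queries,
# and 80 classes.
if __name__ == '__main__':
    num_layers, bs, num_dn, num_match, num_cls = 6, 2, 200, 900, 80
    cls_scores = torch.rand(num_layers, bs, num_dn + num_match, num_cls)
    bbox_preds = torch.rand(num_layers, bs, num_dn + num_match, 4)
    dn_meta = dict(num_denoising_queries=num_dn, num_denoising_groups=10)

    match_cls, match_bbox, dn_cls, dn_bbox = DINOHead.split_outputs(
        cls_scores, bbox_preds, dn_meta)
    # the first num_dn queries go to the denoising part, the rest to matching
    assert dn_cls.shape == (num_layers, bs, num_dn, num_cls)
    assert dn_bbox.shape == (num_layers, bs, num_dn, 4)
    assert match_cls.shape == (num_layers, bs, num_match, num_cls)
    assert match_bbox.shape == (num_layers, bs, num_match, 4)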