lad_head.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional
  3. import torch
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import SampleList
  7. from mmdet.structures.bbox import bbox_overlaps
  8. from mmdet.utils import InstanceList, OptInstanceList
  9. from ..utils import levels_to_images, multi_apply, unpack_gt_instances
  10. from .paa_head import PAAHead
  11. @MODELS.register_module()
  12. class LADHead(PAAHead):
  13. """Label Assignment Head from the paper: `Improving Object Detection by
  14. Label Assignment Distillation <https://arxiv.org/pdf/2108.10520.pdf>`_"""
  15. def get_label_assignment(
  16. self,
  17. cls_scores: List[Tensor],
  18. bbox_preds: List[Tensor],
  19. iou_preds: List[Tensor],
  20. batch_gt_instances: InstanceList,
  21. batch_img_metas: List[dict],
  22. batch_gt_instances_ignore: OptInstanceList = None) -> tuple:
  23. """Get label assignment (from teacher).
  24. Args:
  25. cls_scores (list[Tensor]): Box scores for each scale level
  26. Has shape (N, num_anchors * num_classes, H, W)
  27. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  28. level with shape (N, num_anchors * 4, H, W)
  29. iou_preds (list[Tensor]): iou_preds for each scale
  30. level with shape (N, num_anchors * 1, H, W)
  31. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  32. gt_instance. It usually includes ``bboxes`` and ``labels``
  33. attributes.
  34. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  35. image size, scaling factor, etc.
  36. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  37. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  38. data that is ignored during training and testing.
  39. Defaults to None.
  40. Returns:
  41. tuple: Returns a tuple containing label assignment variables.
  42. - labels (Tensor): Labels of all anchors, each with
  43. shape (num_anchors,).
  44. - labels_weight (Tensor): Label weights of all anchor.
  45. each with shape (num_anchors,).
  46. - bboxes_target (Tensor): BBox targets of all anchors.
  47. each with shape (num_anchors, 4).
  48. - bboxes_weight (Tensor): BBox weights of all anchors.
  49. each with shape (num_anchors, 4).
  50. - pos_inds_flatten (Tensor): Contains all index of positive
  51. sample in all anchor.
  52. - pos_anchors (Tensor): Positive anchors.
  53. - num_pos (int): Number of positive anchors.
  54. """
  55. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  56. assert len(featmap_sizes) == self.prior_generator.num_levels
  57. device = cls_scores[0].device
  58. anchor_list, valid_flag_list = self.get_anchors(
  59. featmap_sizes, batch_img_metas, device=device)
  60. cls_reg_targets = self.get_targets(
  61. anchor_list,
  62. valid_flag_list,
  63. batch_gt_instances,
  64. batch_img_metas,
  65. batch_gt_instances_ignore=batch_gt_instances_ignore,
  66. )
  67. (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
  68. pos_gt_index) = cls_reg_targets
  69. cls_scores = levels_to_images(cls_scores)
  70. cls_scores = [
  71. item.reshape(-1, self.cls_out_channels) for item in cls_scores
  72. ]
  73. bbox_preds = levels_to_images(bbox_preds)
  74. bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
  75. pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
  76. cls_scores, bbox_preds, labels,
  77. labels_weight, bboxes_target,
  78. bboxes_weight, pos_inds)
  79. with torch.no_grad():
  80. reassign_labels, reassign_label_weight, \
  81. reassign_bbox_weights, num_pos = multi_apply(
  82. self.paa_reassign,
  83. pos_losses_list,
  84. labels,
  85. labels_weight,
  86. bboxes_weight,
  87. pos_inds,
  88. pos_gt_index,
  89. anchor_list)
  90. num_pos = sum(num_pos)
  91. # convert all tensor list to a flatten tensor
  92. labels = torch.cat(reassign_labels, 0).view(-1)
  93. flatten_anchors = torch.cat(
  94. [torch.cat(item, 0) for item in anchor_list])
  95. labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
  96. bboxes_target = torch.cat(bboxes_target,
  97. 0).view(-1, bboxes_target[0].size(-1))
  98. pos_inds_flatten = ((labels >= 0)
  99. &
  100. (labels < self.num_classes)).nonzero().reshape(-1)
  101. if num_pos:
  102. pos_anchors = flatten_anchors[pos_inds_flatten]
  103. else:
  104. pos_anchors = None
  105. label_assignment_results = (labels, labels_weight, bboxes_target,
  106. bboxes_weight, pos_inds_flatten,
  107. pos_anchors, num_pos)
  108. return label_assignment_results
  109. def loss(self, x: List[Tensor], label_assignment_results: tuple,
  110. batch_data_samples: SampleList) -> dict:
  111. """Forward train with the available label assignment (student receives
  112. from teacher).
  113. Args:
  114. x (list[Tensor]): Features from FPN.
  115. label_assignment_results (tuple): As the outputs defined in the
  116. function `self.get_label_assignment`.
  117. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  118. data samples. It usually includes information such
  119. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  120. Returns:
  121. losses: (dict[str, Tensor]): A dictionary of loss components.
  122. """
  123. outputs = unpack_gt_instances(batch_data_samples)
  124. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  125. = outputs
  126. outs = self(x)
  127. loss_inputs = outs + (batch_gt_instances, batch_img_metas)
  128. losses = self.loss_by_feat(
  129. *loss_inputs,
  130. batch_gt_instances_ignore=batch_gt_instances_ignore,
  131. label_assignment_results=label_assignment_results)
  132. return losses
  133. def loss_by_feat(self,
  134. cls_scores: List[Tensor],
  135. bbox_preds: List[Tensor],
  136. iou_preds: List[Tensor],
  137. batch_gt_instances: InstanceList,
  138. batch_img_metas: List[dict],
  139. batch_gt_instances_ignore: OptInstanceList = None,
  140. label_assignment_results: Optional[tuple] = None) -> dict:
  141. """Compute losses of the head.
  142. Args:
  143. cls_scores (list[Tensor]): Box scores for each scale level
  144. Has shape (N, num_anchors * num_classes, H, W)
  145. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  146. level with shape (N, num_anchors * 4, H, W)
  147. iou_preds (list[Tensor]): iou_preds for each scale
  148. level with shape (N, num_anchors * 1, H, W)
  149. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  150. gt_instance. It usually includes ``bboxes`` and ``labels``
  151. attributes.
  152. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  153. image size, scaling factor, etc.
  154. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  155. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  156. data that is ignored during training and testing.
  157. Defaults to None.
  158. label_assignment_results (tuple, optional): As the outputs defined
  159. in the function `self.get_
  160. label_assignment`.
  161. Returns:
  162. dict[str, Tensor]: A dictionary of loss gmm_assignment.
  163. """
  164. (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds_flatten,
  165. pos_anchors, num_pos) = label_assignment_results
  166. cls_scores = levels_to_images(cls_scores)
  167. cls_scores = [
  168. item.reshape(-1, self.cls_out_channels) for item in cls_scores
  169. ]
  170. bbox_preds = levels_to_images(bbox_preds)
  171. bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
  172. iou_preds = levels_to_images(iou_preds)
  173. iou_preds = [item.reshape(-1, 1) for item in iou_preds]
  174. # convert all tensor list to a flatten tensor
  175. cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
  176. bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
  177. iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
  178. losses_cls = self.loss_cls(
  179. cls_scores,
  180. labels,
  181. labels_weight,
  182. avg_factor=max(num_pos, len(batch_img_metas))) # avoid num_pos=0
  183. if num_pos:
  184. pos_bbox_pred = self.bbox_coder.decode(
  185. pos_anchors, bbox_preds[pos_inds_flatten])
  186. pos_bbox_target = bboxes_target[pos_inds_flatten]
  187. iou_target = bbox_overlaps(
  188. pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
  189. losses_iou = self.loss_centerness(
  190. iou_preds[pos_inds_flatten],
  191. iou_target.unsqueeze(-1),
  192. avg_factor=num_pos)
  193. losses_bbox = self.loss_bbox(
  194. pos_bbox_pred, pos_bbox_target, avg_factor=num_pos)
  195. else:
  196. losses_iou = iou_preds.sum() * 0
  197. losses_bbox = bbox_preds.sum() * 0
  198. return dict(
  199. loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)