ld_head.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import SampleList
  7. from mmdet.structures.bbox import bbox_overlaps
  8. from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
  9. from ..utils import multi_apply, unpack_gt_instances
  10. from .gfl_head import GFLHead
  11. @MODELS.register_module()
  12. class LDHead(GFLHead):
  13. """Localization distillation Head. (Short description)
  14. It utilizes the learned bbox distributions to transfer the localization
  15. dark knowledge from teacher to student. Original paper: `Localization
  16. Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_
  17. Args:
  18. num_classes (int): Number of categories excluding the background
  19. category.
  20. in_channels (int): Number of channels in the input feature map.
  21. loss_ld (:obj:`ConfigDict` or dict): Config of Localization
  22. Distillation Loss (LD), T is the temperature for distillation.
  23. """
  24. def __init__(self,
  25. num_classes: int,
  26. in_channels: int,
  27. loss_ld: ConfigType = dict(
  28. type='LocalizationDistillationLoss',
  29. loss_weight=0.25,
  30. T=10),
  31. **kwargs) -> dict:
  32. super().__init__(
  33. num_classes=num_classes, in_channels=in_channels, **kwargs)
  34. self.loss_ld = MODELS.build(loss_ld)
  35. def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
  36. bbox_pred: Tensor, labels: Tensor,
  37. label_weights: Tensor, bbox_targets: Tensor,
  38. stride: Tuple[int], soft_targets: Tensor,
  39. avg_factor: int):
  40. """Calculate the loss of a single scale level based on the features
  41. extracted by the detection head.
  42. Args:
  43. anchors (Tensor): Box reference for each scale level with shape
  44. (N, num_total_anchors, 4).
  45. cls_score (Tensor): Cls and quality joint scores for each scale
  46. level has shape (N, num_classes, H, W).
  47. bbox_pred (Tensor): Box distribution logits for each scale
  48. level with shape (N, 4*(n+1), H, W), n is max value of integral
  49. set.
  50. labels (Tensor): Labels of each anchors with shape
  51. (N, num_total_anchors).
  52. label_weights (Tensor): Label weights of each anchor with shape
  53. (N, num_total_anchors)
  54. bbox_targets (Tensor): BBox regression targets of each anchor
  55. weight shape (N, num_total_anchors, 4).
  56. stride (tuple): Stride in this scale level.
  57. soft_targets (Tensor): Soft BBox regression targets.
  58. avg_factor (int): Average factor that is used to average
  59. the loss. When using sampling method, avg_factor is usually
  60. the sum of positive and negative priors. When using
  61. `PseudoSampler`, `avg_factor` is usually equal to the number
  62. of positive priors.
  63. Returns:
  64. dict[tuple, Tensor]: Loss components and weight targets.
  65. """
  66. assert stride[0] == stride[1], 'h stride is not equal to w stride!'
  67. anchors = anchors.reshape(-1, 4)
  68. cls_score = cls_score.permute(0, 2, 3,
  69. 1).reshape(-1, self.cls_out_channels)
  70. bbox_pred = bbox_pred.permute(0, 2, 3,
  71. 1).reshape(-1, 4 * (self.reg_max + 1))
  72. soft_targets = soft_targets.permute(0, 2, 3,
  73. 1).reshape(-1,
  74. 4 * (self.reg_max + 1))
  75. bbox_targets = bbox_targets.reshape(-1, 4)
  76. labels = labels.reshape(-1)
  77. label_weights = label_weights.reshape(-1)
  78. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  79. bg_class_ind = self.num_classes
  80. pos_inds = ((labels >= 0)
  81. & (labels < bg_class_ind)).nonzero().squeeze(1)
  82. score = label_weights.new_zeros(labels.shape)
  83. if len(pos_inds) > 0:
  84. pos_bbox_targets = bbox_targets[pos_inds]
  85. pos_bbox_pred = bbox_pred[pos_inds]
  86. pos_anchors = anchors[pos_inds]
  87. pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
  88. weight_targets = cls_score.detach().sigmoid()
  89. weight_targets = weight_targets.max(dim=1)[0][pos_inds]
  90. pos_bbox_pred_corners = self.integral(pos_bbox_pred)
  91. pos_decode_bbox_pred = self.bbox_coder.decode(
  92. pos_anchor_centers, pos_bbox_pred_corners)
  93. pos_decode_bbox_targets = pos_bbox_targets / stride[0]
  94. score[pos_inds] = bbox_overlaps(
  95. pos_decode_bbox_pred.detach(),
  96. pos_decode_bbox_targets,
  97. is_aligned=True)
  98. pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
  99. pos_soft_targets = soft_targets[pos_inds]
  100. soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
  101. target_corners = self.bbox_coder.encode(pos_anchor_centers,
  102. pos_decode_bbox_targets,
  103. self.reg_max).reshape(-1)
  104. # regression loss
  105. loss_bbox = self.loss_bbox(
  106. pos_decode_bbox_pred,
  107. pos_decode_bbox_targets,
  108. weight=weight_targets,
  109. avg_factor=1.0)
  110. # dfl loss
  111. loss_dfl = self.loss_dfl(
  112. pred_corners,
  113. target_corners,
  114. weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
  115. avg_factor=4.0)
  116. # ld loss
  117. loss_ld = self.loss_ld(
  118. pred_corners,
  119. soft_corners,
  120. weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
  121. avg_factor=4.0)
  122. else:
  123. loss_ld = bbox_pred.sum() * 0
  124. loss_bbox = bbox_pred.sum() * 0
  125. loss_dfl = bbox_pred.sum() * 0
  126. weight_targets = bbox_pred.new_tensor(0)
  127. # cls (qfl) loss
  128. loss_cls = self.loss_cls(
  129. cls_score, (labels, score),
  130. weight=label_weights,
  131. avg_factor=avg_factor)
  132. return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
  133. def loss(self, x: List[Tensor], out_teacher: Tuple[Tensor],
  134. batch_data_samples: SampleList) -> dict:
  135. """
  136. Args:
  137. x (list[Tensor]): Features from FPN.
  138. out_teacher (tuple[Tensor]): The output of teacher.
  139. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  140. data samples. It usually includes information such
  141. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  142. Returns:
  143. tuple[dict, list]: The loss components and proposals of each image.
  144. - losses (dict[str, Tensor]): A dictionary of loss components.
  145. - proposal_list (list[Tensor]): Proposals of each image.
  146. """
  147. outputs = unpack_gt_instances(batch_data_samples)
  148. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  149. = outputs
  150. outs = self(x)
  151. soft_targets = out_teacher[1]
  152. loss_inputs = outs + (batch_gt_instances, batch_img_metas,
  153. soft_targets)
  154. losses = self.loss_by_feat(
  155. *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore)
  156. return losses
  157. def loss_by_feat(
  158. self,
  159. cls_scores: List[Tensor],
  160. bbox_preds: List[Tensor],
  161. batch_gt_instances: InstanceList,
  162. batch_img_metas: List[dict],
  163. soft_targets: List[Tensor],
  164. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  165. """Compute losses of the head.
  166. Args:
  167. cls_scores (list[Tensor]): Cls and quality scores for each scale
  168. level has shape (N, num_classes, H, W).
  169. bbox_preds (list[Tensor]): Box distribution logits for each scale
  170. level with shape (N, 4*(n+1), H, W), n is max value of integral
  171. set.
  172. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  173. gt_instance. It usually includes ``bboxes`` and ``labels``
  174. attributes.
  175. soft_targets (list[Tensor]): Soft BBox regression targets.
  176. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  177. image size, scaling factor, etc.
  178. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  179. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  180. data that is ignored during training and testing.
  181. Defaults to None.
  182. Returns:
  183. dict[str, Tensor]: A dictionary of loss components.
  184. """
  185. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  186. assert len(featmap_sizes) == self.prior_generator.num_levels
  187. device = cls_scores[0].device
  188. anchor_list, valid_flag_list = self.get_anchors(
  189. featmap_sizes, batch_img_metas, device=device)
  190. cls_reg_targets = self.get_targets(
  191. anchor_list,
  192. valid_flag_list,
  193. batch_gt_instances,
  194. batch_img_metas,
  195. batch_gt_instances_ignore=batch_gt_instances_ignore)
  196. (anchor_list, labels_list, label_weights_list, bbox_targets_list,
  197. bbox_weights_list, avg_factor) = cls_reg_targets
  198. avg_factor = reduce_mean(
  199. torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
  200. losses_cls, losses_bbox, losses_dfl, losses_ld, \
  201. avg_factor = multi_apply(
  202. self.loss_by_feat_single,
  203. anchor_list,
  204. cls_scores,
  205. bbox_preds,
  206. labels_list,
  207. label_weights_list,
  208. bbox_targets_list,
  209. self.prior_generator.strides,
  210. soft_targets,
  211. avg_factor=avg_factor)
  212. avg_factor = sum(avg_factor) + 1e-6
  213. avg_factor = reduce_mean(avg_factor).item()
  214. losses_bbox = [x / avg_factor for x in losses_bbox]
  215. losses_dfl = [x / avg_factor for x in losses_dfl]
  216. return dict(
  217. loss_cls=losses_cls,
  218. loss_bbox=losses_bbox,
  219. loss_dfl=losses_dfl,
  220. loss_ld=losses_ld)