# pisa_retinanet_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import InstanceList, OptInstanceList
from ..losses import carl_loss, isr_p
from ..utils import images_to_levels
from .retina_head import RetinaHead
  10. @MODELS.register_module()
  11. class PISARetinaHead(RetinaHead):
  12. """PISA Retinanet Head.
  13. The head owns the same structure with Retinanet Head, but differs in two
  14. aspects:
  15. 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to
  16. change the positive loss weights.
  17. 2. Classification-aware regression loss is adopted as a third loss.
  18. """
  19. def loss_by_feat(
  20. self,
  21. cls_scores: List[Tensor],
  22. bbox_preds: List[Tensor],
  23. batch_gt_instances: InstanceList,
  24. batch_img_metas: List[dict],
  25. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  26. """Compute losses of the head.
  27. Args:
  28. cls_scores (list[Tensor]): Box scores for each scale level
  29. Has shape (N, num_anchors * num_classes, H, W)
  30. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  31. level with shape (N, num_anchors * 4, H, W)
  32. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  33. gt_instance. It usually includes ``bboxes`` and ``labels``
  34. attributes.
  35. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  36. image size, scaling factor, etc.
  37. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  38. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  39. data that is ignored during training and testing.
  40. Defaults to None.
  41. Returns:
  42. dict: Loss dict, comprise classification loss, regression loss and
  43. carl loss.
  44. """
  45. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  46. assert len(featmap_sizes) == self.prior_generator.num_levels
  47. device = cls_scores[0].device
  48. anchor_list, valid_flag_list = self.get_anchors(
  49. featmap_sizes, batch_img_metas, device=device)
  50. label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
  51. cls_reg_targets = self.get_targets(
  52. anchor_list,
  53. valid_flag_list,
  54. batch_gt_instances,
  55. batch_img_metas,
  56. batch_gt_instances_ignore=batch_gt_instances_ignore,
  57. return_sampling_results=True)
  58. if cls_reg_targets is None:
  59. return None
  60. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  61. avg_factor, sampling_results_list) = cls_reg_targets
  62. # anchor number of multi levels
  63. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  64. # concat all level anchors and flags to a single tensor
  65. concat_anchor_list = []
  66. for i in range(len(anchor_list)):
  67. concat_anchor_list.append(torch.cat(anchor_list[i]))
  68. all_anchor_list = images_to_levels(concat_anchor_list,
  69. num_level_anchors)
  70. num_imgs = len(batch_img_metas)
  71. flatten_cls_scores = [
  72. cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels)
  73. for cls_score in cls_scores
  74. ]
  75. flatten_cls_scores = torch.cat(
  76. flatten_cls_scores, dim=1).reshape(-1,
  77. flatten_cls_scores[0].size(-1))
  78. flatten_bbox_preds = [
  79. bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
  80. for bbox_pred in bbox_preds
  81. ]
  82. flatten_bbox_preds = torch.cat(
  83. flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1))
  84. flatten_labels = torch.cat(labels_list, dim=1).reshape(-1)
  85. flatten_label_weights = torch.cat(
  86. label_weights_list, dim=1).reshape(-1)
  87. flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4)
  88. flatten_bbox_targets = torch.cat(
  89. bbox_targets_list, dim=1).reshape(-1, 4)
  90. flatten_bbox_weights = torch.cat(
  91. bbox_weights_list, dim=1).reshape(-1, 4)
  92. # Apply ISR-P
  93. isr_cfg = self.train_cfg.get('isr', None)
  94. if isr_cfg is not None:
  95. all_targets = (flatten_labels, flatten_label_weights,
  96. flatten_bbox_targets, flatten_bbox_weights)
  97. with torch.no_grad():
  98. all_targets = isr_p(
  99. flatten_cls_scores,
  100. flatten_bbox_preds,
  101. all_targets,
  102. flatten_anchors,
  103. sampling_results_list,
  104. bbox_coder=self.bbox_coder,
  105. loss_cls=self.loss_cls,
  106. num_class=self.num_classes,
  107. **self.train_cfg['isr'])
  108. (flatten_labels, flatten_label_weights, flatten_bbox_targets,
  109. flatten_bbox_weights) = all_targets
  110. # For convenience we compute loss once instead separating by fpn level,
  111. # so that we don't need to separate the weights by level again.
  112. # The result should be the same
  113. losses_cls = self.loss_cls(
  114. flatten_cls_scores,
  115. flatten_labels,
  116. flatten_label_weights,
  117. avg_factor=avg_factor)
  118. losses_bbox = self.loss_bbox(
  119. flatten_bbox_preds,
  120. flatten_bbox_targets,
  121. flatten_bbox_weights,
  122. avg_factor=avg_factor)
  123. loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
  124. # CARL Loss
  125. carl_cfg = self.train_cfg.get('carl', None)
  126. if carl_cfg is not None:
  127. loss_carl = carl_loss(
  128. flatten_cls_scores,
  129. flatten_labels,
  130. flatten_bbox_preds,
  131. flatten_bbox_targets,
  132. self.loss_bbox,
  133. **self.train_cfg['carl'],
  134. avg_factor=avg_factor,
  135. sigmoid=True,
  136. num_class=self.num_classes)
  137. loss_dict.update(loss_carl)
  138. return loss_dict