pisa_ssd_head.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Union
  3. import torch
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from mmdet.utils import InstanceList, OptInstanceList
  7. from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p
  8. from ..utils import multi_apply
  9. from .ssd_head import SSDHead
  10. # TODO: add loss evaluator for SSD
  11. @MODELS.register_module()
  12. class PISASSDHead(SSDHead):
  13. """Implementation of `PISA SSD head <https://arxiv.org/abs/1904.04821>`_
  14. Args:
  15. num_classes (int): Number of categories excluding the background
  16. category.
  17. in_channels (Sequence[int]): Number of channels in the input feature
  18. map.
  19. stacked_convs (int): Number of conv layers in cls and reg tower.
  20. Defaults to 0.
  21. feat_channels (int): Number of hidden channels when stacked_convs
  22. > 0. Defaults to 256.
  23. use_depthwise (bool): Whether to use DepthwiseSeparableConv.
  24. Defaults to False.
  25. conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  26. and config conv layer. Defaults to None.
  27. norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  28. and config norm layer. Defaults to None.
  29. act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  30. and config activation layer. Defaults to None.
  31. anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor
  32. generator.
  33. bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder.
  34. reg_decoded_bbox (bool): If true, the regression loss would be
  35. applied directly on decoded bounding boxes, converting both
  36. the predicted boxes and regression targets to absolute
  37. coordinates format. Defaults to False. It should be `True` when
  38. using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
  39. train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of
  40. anchor head.
  41. test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of
  42. anchor head.
  43. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  44. dict], Optional): Initialization config dict.
  45. """ # noqa: W605
  46. def loss_by_feat(
  47. self,
  48. cls_scores: List[Tensor],
  49. bbox_preds: List[Tensor],
  50. batch_gt_instances: InstanceList,
  51. batch_img_metas: List[dict],
  52. batch_gt_instances_ignore: OptInstanceList = None
  53. ) -> Dict[str, Union[List[Tensor], Tensor]]:
  54. """Compute losses of the head.
  55. Args:
  56. cls_scores (list[Tensor]): Box scores for each scale level
  57. Has shape (N, num_anchors * num_classes, H, W)
  58. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  59. level with shape (N, num_anchors * 4, H, W)
  60. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  61. gt_instance. It usually includes ``bboxes`` and ``labels``
  62. attributes.
  63. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  64. image size, scaling factor, etc.
  65. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  66. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  67. data that is ignored during training and testing.
  68. Defaults to None.
  69. Returns:
  70. dict[str, Union[List[Tensor], Tensor]]: A dictionary of loss
  71. components. the dict has components below:
  72. - loss_cls (list[Tensor]): A list containing each feature map \
  73. classification loss.
  74. - loss_bbox (list[Tensor]): A list containing each feature map \
  75. regression loss.
  76. - loss_carl (Tensor): The loss of CARL.
  77. """
  78. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  79. assert len(featmap_sizes) == self.prior_generator.num_levels
  80. device = cls_scores[0].device
  81. anchor_list, valid_flag_list = self.get_anchors(
  82. featmap_sizes, batch_img_metas, device=device)
  83. cls_reg_targets = self.get_targets(
  84. anchor_list,
  85. valid_flag_list,
  86. batch_gt_instances,
  87. batch_img_metas,
  88. batch_gt_instances_ignore=batch_gt_instances_ignore,
  89. unmap_outputs=False,
  90. return_sampling_results=True)
  91. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  92. avg_factor, sampling_results_list) = cls_reg_targets
  93. num_images = len(batch_img_metas)
  94. all_cls_scores = torch.cat([
  95. s.permute(0, 2, 3, 1).reshape(
  96. num_images, -1, self.cls_out_channels) for s in cls_scores
  97. ], 1)
  98. all_labels = torch.cat(labels_list, -1).view(num_images, -1)
  99. all_label_weights = torch.cat(label_weights_list,
  100. -1).view(num_images, -1)
  101. all_bbox_preds = torch.cat([
  102. b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
  103. for b in bbox_preds
  104. ], -2)
  105. all_bbox_targets = torch.cat(bbox_targets_list,
  106. -2).view(num_images, -1, 4)
  107. all_bbox_weights = torch.cat(bbox_weights_list,
  108. -2).view(num_images, -1, 4)
  109. # concat all level anchors to a single tensor
  110. all_anchors = []
  111. for i in range(num_images):
  112. all_anchors.append(torch.cat(anchor_list[i]))
  113. isr_cfg = self.train_cfg.get('isr', None)
  114. all_targets = (all_labels.view(-1), all_label_weights.view(-1),
  115. all_bbox_targets.view(-1,
  116. 4), all_bbox_weights.view(-1, 4))
  117. # apply ISR-P
  118. if isr_cfg is not None:
  119. all_targets = isr_p(
  120. all_cls_scores.view(-1, all_cls_scores.size(-1)),
  121. all_bbox_preds.view(-1, 4),
  122. all_targets,
  123. torch.cat(all_anchors),
  124. sampling_results_list,
  125. loss_cls=CrossEntropyLoss(),
  126. bbox_coder=self.bbox_coder,
  127. **self.train_cfg['isr'],
  128. num_class=self.num_classes)
  129. (new_labels, new_label_weights, new_bbox_targets,
  130. new_bbox_weights) = all_targets
  131. all_labels = new_labels.view(all_labels.shape)
  132. all_label_weights = new_label_weights.view(all_label_weights.shape)
  133. all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape)
  134. all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape)
  135. # add CARL loss
  136. carl_loss_cfg = self.train_cfg.get('carl', None)
  137. if carl_loss_cfg is not None:
  138. loss_carl = carl_loss(
  139. all_cls_scores.view(-1, all_cls_scores.size(-1)),
  140. all_targets[0],
  141. all_bbox_preds.view(-1, 4),
  142. all_targets[2],
  143. SmoothL1Loss(beta=1.),
  144. **self.train_cfg['carl'],
  145. avg_factor=avg_factor,
  146. num_class=self.num_classes)
  147. # check NaN and Inf
  148. assert torch.isfinite(all_cls_scores).all().item(), \
  149. 'classification scores become infinite or NaN!'
  150. assert torch.isfinite(all_bbox_preds).all().item(), \
  151. 'bbox predications become infinite or NaN!'
  152. losses_cls, losses_bbox = multi_apply(
  153. self.loss_by_feat_single,
  154. all_cls_scores,
  155. all_bbox_preds,
  156. all_anchors,
  157. all_labels,
  158. all_label_weights,
  159. all_bbox_targets,
  160. all_bbox_weights,
  161. avg_factor=avg_factor)
  162. loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
  163. if carl_loss_cfg is not None:
  164. loss_dict.update(loss_carl)
  165. return loss_dict