pisa_roi_head.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. from torch import Tensor
  4. from mmdet.models.task_modules import SamplingResult
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import DetDataSample
  7. from mmdet.structures.bbox import bbox2roi
  8. from mmdet.utils import InstanceList
  9. from ..losses.pisa_loss import carl_loss, isr_p
  10. from ..utils import unpack_gt_instances
  11. from .standard_roi_head import StandardRoIHead
  12. @MODELS.register_module()
  13. class PISARoIHead(StandardRoIHead):
  14. r"""The RoI head for `Prime Sample Attention in Object Detection
  15. <https://arxiv.org/abs/1904.04821>`_."""
  16. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  17. batch_data_samples: List[DetDataSample]) -> dict:
  18. """Perform forward propagation and loss calculation of the detection
  19. roi on the features of the upstream network.
  20. Args:
  21. x (tuple[Tensor]): List of multi-level img features.
  22. rpn_results_list (list[:obj:`InstanceData`]): List of region
  23. proposals.
  24. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  25. data samples. It usually includes information such
  26. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  27. Returns:
  28. dict[str, Tensor]: A dictionary of loss components
  29. """
  30. assert len(rpn_results_list) == len(batch_data_samples)
  31. outputs = unpack_gt_instances(batch_data_samples)
  32. batch_gt_instances, batch_gt_instances_ignore, _ = outputs
  33. # assign gts and sample proposals
  34. num_imgs = len(batch_data_samples)
  35. sampling_results = []
  36. neg_label_weights = []
  37. for i in range(num_imgs):
  38. # rename rpn_results.bboxes to rpn_results.priors
  39. rpn_results = rpn_results_list[i]
  40. rpn_results.priors = rpn_results.pop('bboxes')
  41. assign_result = self.bbox_assigner.assign(
  42. rpn_results, batch_gt_instances[i],
  43. batch_gt_instances_ignore[i])
  44. sampling_result = self.bbox_sampler.sample(
  45. assign_result,
  46. rpn_results,
  47. batch_gt_instances[i],
  48. feats=[lvl_feat[i][None] for lvl_feat in x])
  49. if isinstance(sampling_result, tuple):
  50. sampling_result, neg_label_weight = sampling_result
  51. sampling_results.append(sampling_result)
  52. neg_label_weights.append(neg_label_weight)
  53. losses = dict()
  54. # bbox head forward and loss
  55. if self.with_bbox:
  56. bbox_results = self.bbox_loss(
  57. x, sampling_results, neg_label_weights=neg_label_weights)
  58. losses.update(bbox_results['loss_bbox'])
  59. # mask head forward and loss
  60. if self.with_mask:
  61. mask_results = self.mask_loss(x, sampling_results,
  62. bbox_results['bbox_feats'],
  63. batch_gt_instances)
  64. losses.update(mask_results['loss_mask'])
  65. return losses
  66. def bbox_loss(self,
  67. x: Tuple[Tensor],
  68. sampling_results: List[SamplingResult],
  69. neg_label_weights: List[Tensor] = None) -> dict:
  70. """Perform forward propagation and loss calculation of the bbox head on
  71. the features of the upstream network.
  72. Args:
  73. x (tuple[Tensor]): List of multi-level img features.
  74. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  75. Returns:
  76. dict[str, Tensor]: Usually returns a dictionary with keys:
  77. - `cls_score` (Tensor): Classification scores.
  78. - `bbox_pred` (Tensor): Box energies / deltas.
  79. - `bbox_feats` (Tensor): Extract bbox RoI features.
  80. - `loss_bbox` (dict): A dictionary of bbox loss components.
  81. """
  82. rois = bbox2roi([res.priors for res in sampling_results])
  83. bbox_results = self._bbox_forward(x, rois)
  84. bbox_targets = self.bbox_head.get_targets(sampling_results,
  85. self.train_cfg)
  86. # neg_label_weights obtained by sampler is image-wise, mapping back to
  87. # the corresponding location in label weights
  88. if neg_label_weights[0] is not None:
  89. label_weights = bbox_targets[1]
  90. cur_num_rois = 0
  91. for i in range(len(sampling_results)):
  92. num_pos = sampling_results[i].pos_inds.size(0)
  93. num_neg = sampling_results[i].neg_inds.size(0)
  94. label_weights[cur_num_rois + num_pos:cur_num_rois + num_pos +
  95. num_neg] = neg_label_weights[i]
  96. cur_num_rois += num_pos + num_neg
  97. cls_score = bbox_results['cls_score']
  98. bbox_pred = bbox_results['bbox_pred']
  99. # Apply ISR-P
  100. isr_cfg = self.train_cfg.get('isr', None)
  101. if isr_cfg is not None:
  102. bbox_targets = isr_p(
  103. cls_score,
  104. bbox_pred,
  105. bbox_targets,
  106. rois,
  107. sampling_results,
  108. self.bbox_head.loss_cls,
  109. self.bbox_head.bbox_coder,
  110. **isr_cfg,
  111. num_class=self.bbox_head.num_classes)
  112. loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
  113. *bbox_targets)
  114. # Add CARL Loss
  115. carl_cfg = self.train_cfg.get('carl', None)
  116. if carl_cfg is not None:
  117. loss_carl = carl_loss(
  118. cls_score,
  119. bbox_targets[0],
  120. bbox_pred,
  121. bbox_targets[2],
  122. self.bbox_head.loss_bbox,
  123. **carl_cfg,
  124. num_class=self.bbox_head.num_classes)
  125. loss_bbox.update(loss_carl)
  126. bbox_results.update(loss_bbox=loss_bbox)
  127. return bbox_results