dynamic_roi_head.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import numpy as np
  4. import torch
  5. from torch import Tensor
  6. from mmdet.models.losses import SmoothL1Loss
  7. from mmdet.models.task_modules.samplers import SamplingResult
  8. from mmdet.registry import MODELS
  9. from mmdet.structures import SampleList
  10. from mmdet.structures.bbox import bbox2roi
  11. from mmdet.utils import InstanceList
  12. from ..utils.misc import unpack_gt_instances
  13. from .standard_roi_head import StandardRoIHead
  14. EPS = 1e-15  # a beta-history median below this leaves the SmoothL1 beta unchanged
  15. @MODELS.register_module()
  16. class DynamicRoIHead(StandardRoIHead):
  17. """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_."""
  18. def __init__(self, **kwargs) -> None:
  19. super().__init__(**kwargs)
  20. assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
  21. # the IoU history of the past `update_iter_interval` iterations
  22. self.iou_history = []
  23. # the beta history of the past `update_iter_interval` iterations
  24. self.beta_history = []
  25. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  26. batch_data_samples: SampleList) -> dict:
  27. """Forward function for training.
  28. Args:
  29. x (tuple[Tensor]): List of multi-level img features.
  30. rpn_results_list (list[:obj:`InstanceData`]): List of region
  31. proposals.
  32. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  33. data samples. It usually includes information such
  34. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  35. Returns:
  36. dict[str, Tensor]: a dictionary of loss components
  37. """
  38. assert len(rpn_results_list) == len(batch_data_samples)
  39. outputs = unpack_gt_instances(batch_data_samples)
  40. batch_gt_instances, batch_gt_instances_ignore, _ = outputs
  41. # assign gts and sample proposals
  42. num_imgs = len(batch_data_samples)
  43. sampling_results = []
  44. cur_iou = []
  45. for i in range(num_imgs):
  46. # rename rpn_results.bboxes to rpn_results.priors
  47. rpn_results = rpn_results_list[i]
  48. rpn_results.priors = rpn_results.pop('bboxes')
  49. assign_result = self.bbox_assigner.assign(
  50. rpn_results, batch_gt_instances[i],
  51. batch_gt_instances_ignore[i])
  52. sampling_result = self.bbox_sampler.sample(
  53. assign_result,
  54. rpn_results,
  55. batch_gt_instances[i],
  56. feats=[lvl_feat[i][None] for lvl_feat in x])
  57. # record the `iou_topk`-th largest IoU in an image
  58. iou_topk = min(self.train_cfg.dynamic_rcnn.iou_topk,
  59. len(assign_result.max_overlaps))
  60. ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
  61. cur_iou.append(ious[-1].item())
  62. sampling_results.append(sampling_result)
  63. # average the current IoUs over images
  64. cur_iou = np.mean(cur_iou)
  65. self.iou_history.append(cur_iou)
  66. losses = dict()
  67. # bbox head forward and loss
  68. if self.with_bbox:
  69. bbox_results = self.bbox_loss(x, sampling_results)
  70. losses.update(bbox_results['loss_bbox'])
  71. # mask head forward and loss
  72. if self.with_mask:
  73. mask_results = self.mask_loss(x, sampling_results,
  74. bbox_results['bbox_feats'],
  75. batch_gt_instances)
  76. losses.update(mask_results['loss_mask'])
  77. # update IoU threshold and SmoothL1 beta
  78. update_iter_interval = self.train_cfg.dynamic_rcnn.update_iter_interval
  79. if len(self.iou_history) % update_iter_interval == 0:
  80. new_iou_thr, new_beta = self.update_hyperparameters()
  81. return losses
  82. def bbox_loss(self, x: Tuple[Tensor],
  83. sampling_results: List[SamplingResult]) -> dict:
  84. """Perform forward propagation and loss calculation of the bbox head on
  85. the features of the upstream network.
  86. Args:
  87. x (tuple[Tensor]): List of multi-level img features.
  88. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  89. Returns:
  90. dict[str, Tensor]: Usually returns a dictionary with keys:
  91. - `cls_score` (Tensor): Classification scores.
  92. - `bbox_pred` (Tensor): Box energies / deltas.
  93. - `bbox_feats` (Tensor): Extract bbox RoI features.
  94. - `loss_bbox` (dict): A dictionary of bbox loss components.
  95. """
  96. rois = bbox2roi([res.priors for res in sampling_results])
  97. bbox_results = self._bbox_forward(x, rois)
  98. bbox_loss_and_target = self.bbox_head.loss_and_target(
  99. cls_score=bbox_results['cls_score'],
  100. bbox_pred=bbox_results['bbox_pred'],
  101. rois=rois,
  102. sampling_results=sampling_results,
  103. rcnn_train_cfg=self.train_cfg)
  104. bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox'])
  105. # record the `beta_topk`-th smallest target
  106. # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
  107. # and bbox_weights, respectively
  108. bbox_targets = bbox_loss_and_target['bbox_targets']
  109. pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
  110. num_pos = len(pos_inds)
  111. num_imgs = len(sampling_results)
  112. if num_pos > 0:
  113. cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
  114. beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
  115. num_pos)
  116. cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
  117. self.beta_history.append(cur_target)
  118. return bbox_results
  119. def update_hyperparameters(self):
  120. """Update hyperparameters like IoU thresholds for assigner and beta for
  121. SmoothL1 loss based on the training statistics.
  122. Returns:
  123. tuple[float]: the updated ``iou_thr`` and ``beta``.
  124. """
  125. new_iou_thr = max(self.train_cfg.dynamic_rcnn.initial_iou,
  126. np.mean(self.iou_history))
  127. self.iou_history = []
  128. self.bbox_assigner.pos_iou_thr = new_iou_thr
  129. self.bbox_assigner.neg_iou_thr = new_iou_thr
  130. self.bbox_assigner.min_pos_iou = new_iou_thr
  131. if (not self.beta_history) or (np.median(self.beta_history) < EPS):
  132. # avoid 0 or too small value for new_beta
  133. new_beta = self.bbox_head.loss_bbox.beta
  134. else:
  135. new_beta = min(self.train_cfg.dynamic_rcnn.initial_beta,
  136. np.median(self.beta_history))
  137. self.beta_history = []
  138. self.bbox_head.loss_bbox.beta = new_beta
  139. return new_iou_thr, new_beta