123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple
- import numpy as np
- import torch
- from torch import Tensor
- from mmdet.models.losses import SmoothL1Loss
- from mmdet.models.task_modules.samplers import SamplingResult
- from mmdet.registry import MODELS
- from mmdet.structures import SampleList
- from mmdet.structures.bbox import bbox2roi
- from mmdet.utils import InstanceList
- from ..utils.misc import unpack_gt_instances
- from .standard_roi_head import StandardRoIHead
# Lower bound for the median SmoothL1 beta statistic: a median below this
# value is considered too small to use, and the current beta is kept instead
# (checked in ``DynamicRoIHead.update_hyperparameters``).
EPS = 1.0e-15
@MODELS.register_module()
class DynamicRoIHead(StandardRoIHead):
    """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_.

    During training the head records two statistics and periodically uses
    them to adapt its hyperparameters:

    - per image, the ``iou_topk``-th largest value of the assigner's
      ``max_overlaps``; the per-batch average is appended to
      ``iou_history`` and its mean (bounded below by
      ``train_cfg.dynamic_rcnn.initial_iou``) becomes the assigner's
      ``pos_iou_thr`` / ``neg_iou_thr`` / ``min_pos_iou``;
    - per batch, the ``beta_topk * num_imgs``-th smallest mean absolute
      value of the first two regression-target columns of the positive
      samples; the median of ``beta_history`` (bounded above by
      ``train_cfg.dynamic_rcnn.initial_beta``) becomes the ``beta`` of the
      SmoothL1 bbox loss.

    The update runs every ``train_cfg.dynamic_rcnn.update_iter_interval``
    recorded iterations.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # `update_hyperparameters` tunes `loss_bbox.beta`, so the bbox loss
        # must be a SmoothL1Loss.
        assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
        # the IoU history of the past `update_iter_interval` iterations
        self.iou_history = []
        # the beta history of the past `update_iter_interval` iterations
        self.beta_history = []

    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: SampleList) -> dict:
        """Forward function for training.

        Besides computing the losses, this records the IoU statistic of the
        current batch and, every ``update_iter_interval`` recorded
        iterations, updates the assigner IoU thresholds and the SmoothL1
        ``beta`` via :meth:`update_hyperparameters`.

        Note: each ``rpn_results`` is modified in place; its ``bboxes``
        field is moved to ``priors``.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        assert len(rpn_results_list) == len(batch_data_samples)
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(batch_data_samples)
        sampling_results = []
        cur_iou = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')

            assign_result = self.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])

            # record the `iou_topk`-th largest IoU in an image
            max_overlaps = assign_result.max_overlaps
            iou_topk = min(self.train_cfg.dynamic_rcnn.iou_topk,
                           len(max_overlaps))
            # An image without proposals has nothing to record; indexing
            # the empty top-k result would raise IndexError.
            if iou_topk > 0:
                ious, _ = torch.topk(max_overlaps, iou_topk)
                cur_iou.append(ious[-1].item())
            sampling_results.append(sampling_result)

        # Average the current IoUs over images. Skip the record when no
        # image contributed, because np.mean([]) is NaN and would poison
        # the history mean.
        if cur_iou:
            self.iou_history.append(np.mean(cur_iou))

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self.bbox_loss(x, sampling_results)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        # NOTE: reuses `bbox_results['bbox_feats']`, so this branch assumes
        # `with_bbox` is also enabled.
        if self.with_mask:
            mask_results = self.mask_loss(x, sampling_results,
                                          bbox_results['bbox_feats'],
                                          batch_gt_instances)
            losses.update(mask_results['loss_mask'])

        # update IoU threshold and SmoothL1 beta
        update_iter_interval = self.train_cfg.dynamic_rcnn.update_iter_interval
        # `update_hyperparameters` clears `iou_history`, so its length counts
        # the iterations recorded since the last update. The non-empty check
        # stops an iteration that recorded nothing from triggering an update
        # on an empty history.
        if (self.iou_history
                and len(self.iou_history) % update_iter_interval == 0):
            self.update_hyperparameters()

        return losses

    def bbox_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult]) -> dict:
        """Perform forward propagation and loss calculation of the bbox head on
        the features of the upstream network.

        As a side effect, when the batch has at least one positive sample,
        the beta statistic of this batch is appended to ``beta_history``.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): Extract bbox RoI features.
                - `loss_bbox` (dict): A dictionary of bbox loss components.
        """
        rois = bbox2roi([res.priors for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_loss_and_target = self.bbox_head.loss_and_target(
            cls_score=bbox_results['cls_score'],
            bbox_pred=bbox_results['bbox_pred'],
            rois=rois,
            sampling_results=sampling_results,
            rcnn_train_cfg=self.train_cfg)
        bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox'])

        # record the `beta_topk`-th smallest target
        # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
        # and bbox_weights, respectively
        bbox_targets = bbox_loss_and_target['bbox_targets']
        # rows whose first weight column is non-zero are the positives
        pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
        num_pos = len(pos_inds)
        num_imgs = len(sampling_results)
        if num_pos > 0:
            # mean absolute value of the first two target columns per sample
            pos_targets = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
            # scale k by the batch size, but never beyond the positives
            beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
                            num_pos)
            cur_beta = torch.kthvalue(pos_targets, beta_topk)[0].item()
            self.beta_history.append(cur_beta)

        return bbox_results

    def update_hyperparameters(self) -> Tuple[float, float]:
        """Update hyperparameters like IoU thresholds for assigner and beta for
        SmoothL1 loss based on the training statistics.

        Both histories are cleared afterwards.

        Returns:
            tuple[float]: the updated ``iou_thr`` and ``beta``.
        """
        # the IoU threshold never drops below the configured initial value
        new_iou_thr = max(self.train_cfg.dynamic_rcnn.initial_iou,
                          np.mean(self.iou_history))
        self.iou_history = []
        self.bbox_assigner.pos_iou_thr = new_iou_thr
        self.bbox_assigner.neg_iou_thr = new_iou_thr
        self.bbox_assigner.min_pos_iou = new_iou_thr

        median_beta = (np.median(self.beta_history)
                       if self.beta_history else None)
        if median_beta is None or median_beta < EPS:
            # avoid 0 or too small value for new_beta
            new_beta = self.bbox_head.loss_bbox.beta
        else:
            # beta never grows above the configured initial value
            new_beta = min(self.train_cfg.dynamic_rcnn.initial_beta,
                           median_beta)
        self.beta_history = []
        self.bbox_head.loss_bbox.beta = new_beta
        return new_iou_thr, new_beta
|