# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn.modules.utils import _pair

from mmdet.models.layers import multiclass_nms
from mmdet.models.losses import accuracy
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances, multi_apply
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import get_box_tensor, scale_boxes
from mmdet.utils import ConfigType, InstanceList, OptMultiConfig


@MODELS.register_module()
class BBoxHead(BaseModule):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Args:
        with_avg_pool (bool): If True, an ``nn.AvgPool2d`` with kernel
            ``roi_feat_size`` is applied to the input in ``forward``.
            Otherwise the predictor input width is
            ``in_channels * roi_feat_area``. Defaults to False.
        with_cls (bool): Whether to build the classification predictor
            ``fc_cls``. Defaults to True.
        with_reg (bool): Whether to build the regression predictor
            ``fc_reg``. Defaults to True.
        roi_feat_size (int): Size of the RoI feature map, expanded to a
            pair with ``_pair``. Defaults to 7.
        in_channels (int): Number of channels of the RoI features.
            Defaults to 256.
        num_classes (int): Number of foreground classes. The background
            class uses index ``num_classes``. Defaults to 80.
        bbox_coder (:obj:`ConfigDict` or dict): Config built with
            ``TASK_UTILS.build``. Its ``encode_size`` defines the width of
            a single regression output.
        predict_box_type (str): Box type forwarded to ``empty_instances``
            when an image has no RoI. Defaults to 'hbox'.
        reg_class_agnostic (bool): If True, the regression predictor
            outputs ``encode_size`` values per RoI, otherwise
            ``encode_size * num_classes``. Defaults to False.
        reg_decoded_bbox (bool): If True, the regression loss is computed
            on decoded boxes and the regression targets are the ground
            truth boxes themselves. Defaults to False.
        reg_predictor_cfg (:obj:`ConfigDict` or dict): Config of the
            regression predictor, built with ``MODELS.build``.
        cls_predictor_cfg (:obj:`ConfigDict` or dict): Config of the
            classification predictor, built with ``MODELS.build``.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of the regression
            loss.
        init_cfg (:obj:`ConfigDict` or dict or list, optional):
            Initialization config. When None, a ``Normal`` initialization
            is registered for ``fc_cls`` (std=0.01) and ``fc_reg``
            (std=0.001), depending on which predictors exist.
    """

    def __init__(self,
                 with_avg_pool: bool = False,
                 with_cls: bool = True,
                 with_reg: bool = True,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 num_classes: int = 80,
                 bbox_coder: ConfigType = dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 predict_box_type: str = 'hbox',
                 reg_class_agnostic: bool = False,
                 reg_decoded_bbox: bool = False,
                 reg_predictor_cfg: ConfigType = dict(type='Linear'),
                 cls_predictor_cfg: ConfigType = dict(type='Linear'),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.predict_box_type = predict_box_type
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        self.reg_predictor_cfg = reg_predictor_cfg
        self.cls_predictor_cfg = cls_predictor_cfg

        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        # `loss_cls` must exist before `custom_cls_channels` is queried below.
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)

        # Width of the feature vector fed to the predictors.
        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            if self.custom_cls_channels:
                cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
            else:
                cls_channels = num_classes + 1
            # Work on a copy so the stored config is not mutated.
            cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
            cls_predictor_cfg_.update(
                in_features=in_channels, out_features=cls_channels)
            self.fc_cls = MODELS.build(cls_predictor_cfg_)
        if self.with_reg:
            box_dim = self.bbox_coder.encode_size
            out_dim_reg = box_dim if reg_class_agnostic else \
                box_dim * num_classes
            reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
            if isinstance(reg_predictor_cfg_, (dict, ConfigDict)):
                reg_predictor_cfg_.update(
                    in_features=in_channels, out_features=out_dim_reg)
            self.fc_reg = MODELS.build(reg_predictor_cfg_)
        self.debug_imgs = None
        if init_cfg is None:
            # Default initialization for whichever predictors were built.
            self.init_cfg = []
            if self.with_cls:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.01, override=dict(name='fc_cls'))
                ]
            if self.with_reg:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.001, override=dict(name='fc_reg'))
                ]

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_cls_channels(self) -> bool:
        """bool: ``custom_cls_channels`` of ``loss_cls`` (False if the loss
        does not define it)."""
        return getattr(self.loss_cls, 'custom_cls_channels', False)

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_activation(self) -> bool:
        """bool: ``custom_activation`` of ``loss_cls`` (False if the loss
        does not define it)."""
        return getattr(self.loss_cls, 'custom_activation', False)

    # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
    @property
    def custom_accuracy(self) -> bool:
        """bool: ``custom_accuracy`` of ``loss_cls`` (False if the loss
        does not define it)."""
        return getattr(self.loss_cls, 'custom_accuracy', False)

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward RoI features through the predictors.

        Args:
            x (Tensor): RoI features. Although annotated as a tuple, the
                value is used as a single tensor (``x.numel()``,
                ``x.view``). When ``with_avg_pool`` is False it is passed
                to the predictors unchanged, so it must already have the
                layout they expect.

        Returns:
            tuple: A tuple of classification scores and bbox prediction.

            - cls_score (Tensor or None): Output of ``fc_cls``, None when
              ``with_cls`` is False.
            - bbox_pred (Tensor or None): Output of ``fc_reg``, None when
              ``with_reg`` is False.
        """
        if self.with_avg_pool:
            if x.numel() > 0:
                x = self.avg_pool(x)
                x = x.view(x.size(0), -1)
            else:
                # avg_pool does not support empty tensors,
                # so use torch.mean over the spatial dims instead
                x = torch.mean(x, dim=(-1, -2))
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def _get_targets_single(self, pos_priors: Tensor, neg_priors: Tensor,
                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                            cfg: ConfigDict) -> tuple:
        """Calculate the ground truth for proposals in the single image
        according to the sampling results.

        Positive samples occupy the first ``num_pos`` rows of every output
        and negative samples the remaining ``num_neg`` rows.

        Args:
            pos_priors (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_priors (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains gt_boxes for
                all positive samples, has shape (num_pos, 4),
                the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains gt_labels for
                all positive samples, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. Only
                ``cfg.pos_weight`` is read; a value <= 0 means weight 1.0.

        Returns:
            Tuple[Tensor]: Ground truth for proposals
            in a single image. Containing the following Tensors:

            - labels(Tensor): Gt_labels for all proposals, has
              shape (num_proposals,). Negatives are labelled
              ``num_classes``.
            - label_weights(Tensor): Labels_weights for all
              proposals, has shape (num_proposals,).
            - bbox_targets(Tensor): Regression target for all
              proposals, has shape (num_proposals, reg_dim). Rows of
              positives hold the coder-encoded targets, or the gt boxes
              themselves when ``reg_decoded_bbox`` is True; rows of
              negatives are zero.
            - bbox_weights(Tensor): Regression weights for all
              proposals, has shape (num_proposals, reg_dim). 1 for
              positives and 0 for negatives.
        """
        num_pos = pos_priors.size(0)
        num_neg = neg_priors.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_priors.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        reg_dim = pos_gt_bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        label_weights = pos_priors.new_zeros(num_samples)
        bbox_targets = pos_priors.new_zeros(num_samples, reg_dim)
        bbox_weights = pos_priors.new_zeros(num_samples, reg_dim)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_priors, pos_gt_bboxes)
            else:
                # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
                # is applied directly on the decoded bounding boxes, both
                # the predicted boxes and regression targets should be with
                # absolute coordinate format.
                pos_bbox_targets = get_box_tensor(pos_gt_bboxes)
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
        """Calculate the ground truth for all samples in a batch according
        to the sampling_results.

        The positive priors, negative priors, positive gt boxes and
        positive gt labels of every sampling result are passed to
        `_get_targets_single`, one image at a time.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

            - labels (list[Tensor],Tensor): Gt_labels for all
              proposals in a batch, each tensor in list has
              shape (num_proposals,) when `concat=False`, otherwise
              just a single tensor has shape (num_all_proposals,).
            - label_weights (list[Tensor],Tensor): Labels_weights for
              all proposals in a batch, each tensor in list has
              shape (num_proposals,) when `concat=False`, otherwise
              just a single tensor has shape (num_all_proposals,).
            - bbox_targets (list[Tensor],Tensor): Regression target
              for all proposals in a batch, each tensor in list
              has shape (num_proposals, reg_dim) when `concat=False`,
              otherwise just a single tensor has shape
              (num_all_proposals, reg_dim).
            - bbox_weights (list[Tensor],Tensor): Regression weights for
              all proposals in a batch, each tensor in list has shape
              (num_proposals, reg_dim) when `concat=False`, otherwise just
              a single tensor has shape (num_all_proposals, reg_dim).
        """
        pos_priors_list = [res.pos_priors for res in sampling_results]
        neg_priors_list = [res.neg_priors for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_targets_single,
            pos_priors_list,
            neg_priors_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    def loss_and_target(self,
                        cls_score: Tensor,
                        bbox_pred: Tensor,
                        rois: Tensor,
                        sampling_results: List[SamplingResult],
                        rcnn_train_cfg: ConfigDict,
                        concat: bool = True,
                        reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the features extracted by the bbox
        head.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes)
            bbox_pred (Tensor): Regression prediction results,
                has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch. Defaults to True.
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None,

        Returns:
            dict: A dictionary with two keys: ``loss_bbox`` holds the dict
            of losses returned by :meth:`loss`, and ``bbox_targets`` holds
            the tuple returned by :meth:`get_targets`. The targets are
            only used for cascade rcnn.
        """

        cls_reg_targets = self.get_targets(
            sampling_results, rcnn_train_cfg, concat=concat)
        losses = self.loss(
            cls_score,
            bbox_pred,
            rois,
            *cls_reg_targets,
            reduction_override=reduction_override)

        # cls_reg_targets is only for cascade rcnn
        return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)

    def loss(self,
             cls_score: Tensor,
             bbox_pred: Tensor,
             rois: Tensor,
             labels: Tensor,
             label_weights: Tensor,
             bbox_targets: Tensor,
             bbox_weights: Tensor,
             reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the network predictions and targets.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes)
            bbox_pred (Tensor): Regression prediction results,
                has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            labels (Tensor): Gt_labels for all proposals in a batch, has
                shape (batch_size * num_proposals_single_image, ).
            label_weights (Tensor): Labels_weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, ).
            bbox_targets (Tensor): Regression target for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4),
                the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
            bbox_weights (Tensor): Regression weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4).
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None,

        Returns:
            dict: A dictionary of loss. Classification entries are only
            added when ``cls_score`` is not None and non-empty;
            ``loss_bbox`` is only added when ``bbox_pred`` is not None.
        """

        losses = dict()

        if cls_score is not None:
            # Normalize by the number of samples with a positive weight,
            # but never by less than 1.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                loss_cls_ = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_cls_, dict):
                    losses.update(loss_cls_)
                else:
                    losses['loss_cls'] = loss_cls_
                if self.custom_activation:
                    acc_ = self.loss_cls.get_accuracy(cls_score, labels)
                    losses.update(acc_)
                else:
                    losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            # `pos_inds` is already a boolean mask, no cast is required.
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                    bbox_pred = get_box_tensor(bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1)[pos_inds]
                else:
                    # Pick, for every positive RoI, the prediction that
                    # belongs to its ground truth class.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), self.num_classes,
                        -1)[pos_inds, labels[pos_inds]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_inds],
                    bbox_weights[pos_inds],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # No positive sample: sum over an empty selection.
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()

        return losses

    def predict_by_feat(self,
                        rois: Tuple[Tensor],
                        cls_scores: Tuple[Tensor],
                        bbox_preds: Tuple[Tensor],
                        batch_img_metas: List[dict],
                        rcnn_test_cfg: Optional[ConfigDict] = None,
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            rois (tuple[Tensor]): Tuple of boxes to be transformed.
                Each has shape  (num_boxes, 5). last dimension 5 arrange as
                (batch_index, x1, y1, x2, y2).
            cls_scores (tuple[Tensor]): Tuple of box scores, each has shape
                (num_boxes, num_classes + 1).
            bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each
                has shape (num_boxes, num_classes * 4).
            batch_img_metas (list[dict]): List of image information.
            rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process, one item per entry of
            ``batch_img_metas``. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            results = self._predict_by_feat_single(
                roi=rois[img_id],
                cls_score=cls_scores[img_id],
                bbox_pred=bbox_preds[img_id],
                img_meta=img_meta,
                rescale=rescale,
                rcnn_test_cfg=rcnn_test_cfg)
            result_list.append(results)

        return result_list

    def _predict_by_feat_single(
            self,
            roi: Tensor,
            cls_score: Tensor,
            bbox_pred: Tensor,
            img_meta: dict,
            rescale: bool = False,
            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
                last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor): Box energies / deltas.
                has shape (num_boxes, num_classes * 4). May be None, in
                which case the RoIs themselves are used as boxes and are
                clipped to ``img_meta['img_shape']``.
            img_meta (dict): image information. ``img_shape`` is always
                read; ``scale_factor`` is required when ``rescale`` is
                True.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                When None, the raw boxes and scores are returned without
                NMS. Defaults to None

        Returns:
            :obj:`InstanceData`: Detection results of each image\
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ). Only set when ``rcnn_test_cfg`` is given.
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        results = InstanceData()
        if roi.shape[0] == 0:
            return empty_instances([img_meta],
                                   roi.device,
                                   task_type='bbox',
                                   instance_results=[results],
                                   box_type=self.predict_box_type,
                                   use_box_type=False,
                                   num_classes=self.num_classes,
                                   score_per_cls=rcnn_test_cfg is None)[0]

        # some loss (Seesaw loss..) may have custom activation
        if self.custom_cls_channels:
            scores = self.loss_cls.get_activation(cls_score)
        else:
            scores = F.softmax(
                cls_score, dim=-1) if cls_score is not None else None

        img_shape = img_meta['img_shape']
        num_rois = roi.size(0)
        # bbox_pred would be None in some detector when with_reg is False,
        # e.g. Grid R-CNN.
        if bbox_pred is not None:
            num_classes = 1 if self.reg_class_agnostic else self.num_classes
            # One RoI row per predicted box, so the coder sees matching rows.
            roi = roi.repeat_interleave(num_classes, dim=0)
            bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
            bboxes = self.bbox_coder.decode(
                roi[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = roi[:, 1:].clone()
            if img_shape is not None and bboxes.size(-1) == 4:
                # Strided slices are views, so `clamp_` modifies `bboxes`
                # itself. Indexing with an integer list (`[:, [0, 2]]`)
                # would return a copy and silently discard the clamp.
                bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
                bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            assert img_meta.get('scale_factor') is not None
            # Invert the resize factors to map back to the original image.
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            bboxes = scale_boxes(bboxes, scale_factor)

        # Get the inside tensor when `bboxes` is a box type
        bboxes = get_box_tensor(bboxes)
        box_dim = bboxes.size(-1)
        bboxes = bboxes.view(num_rois, -1)

        if rcnn_test_cfg is None:
            # This means that it is aug test.
            # It needs to return the raw results without nms.
            results.bboxes = bboxes
            results.scores = scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                rcnn_test_cfg.score_thr,
                rcnn_test_cfg.nms,
                rcnn_test_cfg.max_per_img,
                box_dim=box_dim)
            # The last column of `det_bboxes` is the score.
            results.bboxes = det_bboxes[:, :-1]
            results.scores = det_bboxes[:, -1]
            results.labels = det_labels
        return results

    def refine_bboxes(self, sampling_results: Union[List[SamplingResult],
                                                    InstanceList],
                      bbox_results: dict,
                      batch_img_metas: List[dict]) -> InstanceList:
        """Refine bboxes during training.

        Args:
            sampling_results (List[:obj:`SamplingResult`] or
                List[:obj:`InstanceData`]): Sampling results.
                :obj:`SamplingResult` is the real sampling results
                calculate from bbox_head, while :obj:`InstanceData` is
                fake sampling results, e.g., in Sparse R-CNN or QueryInst, etc.
            bbox_results (dict): Usually is a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `rois` (Tensor): RoIs with the shape (n, 5) where the first
                  column indicates batch id of each RoI.
                - `bbox_targets` (tuple):  Ground truth for proposals in a
                  single image. Containing the following list of Tensors:
                  (labels, label_weights, bbox_targets, bbox_weights)
            batch_img_metas (List[dict]): List of image information.

        Returns:
            list[:obj:`InstanceData`]: Refined bboxes of each image. None is
            returned when the classification scores are empty.

        Raises:
            ValueError: If the last dimension of the classification scores
                is neither ``num_classes`` nor ``num_classes + 1``.

        Example:
            >>> # xdoctest: +REQUIRES(module:kwarray)
            >>> import numpy as np
            >>> from mmdet.models.task_modules.samplers.sampling_result import (
            ...     random_boxes)
            >>> from mmdet.models.task_modules.samplers import SamplingResult
            >>> self = BBoxHead(reg_class_agnostic=True)
            >>> n_roi = 2
            >>> n_img = 4
            >>> scale = 512
            >>> rng = np.random.RandomState(0)
            >>> batch_img_metas = [{'img_shape': (scale, scale)}
            ...                    for _ in range(n_img)]
            >>> sampling_results = [SamplingResult.random(rng=10)
            ...                     for _ in range(n_img)]
            >>> # Create rois in the expected format
            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
            >>> img_ids = torch.randint(0, n_img, (n_roi,))
            >>> img_ids = img_ids.float()
            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
            >>> # Create other args
            >>> labels = torch.randint(0, 81, (scale,)).long()
            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
            >>> cls_score = torch.randn((scale, 81))
            >>> # For each image, pretend random positive boxes are gts
            >>> bbox_targets = (labels, None, None, None)
            >>> bbox_results = dict(rois=rois, bbox_pred=bbox_preds,
            ...                     cls_score=cls_score,
            ...                     bbox_targets=bbox_targets)
            >>> bboxes_list = self.refine_bboxes(sampling_results,
            ...                                  bbox_results,
            ...                                  batch_img_metas)
            >>> print(bboxes_list)
        """
        pos_is_gts = [res.pos_is_gt for res in sampling_results]
        # bbox_targets is a tuple
        labels = bbox_results['bbox_targets'][0]
        cls_scores = bbox_results['cls_score']
        rois = bbox_results['rois']
        bbox_preds = bbox_results['bbox_pred']
        if self.custom_activation:
            # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
            cls_scores = self.loss_cls.get_activation(cls_scores)
        if cls_scores.numel() == 0:
            return None
        if cls_scores.shape[-1] == self.num_classes + 1:
            # remove background class
            cls_scores = cls_scores[:, :-1]
        elif cls_scores.shape[-1] != self.num_classes:
            raise ValueError('The last dim of `cls_scores` should equal to '
                             '`num_classes` or `num_classes + 1`,'
                             f'but got {cls_scores.shape[-1]}.')
        # Background RoIs are regressed with their best foreground class.
        labels = torch.where(labels == self.num_classes, cls_scores.argmax(1),
                             labels)

        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(batch_img_metas)

        results_list = []
        for i, img_meta_ in enumerate(batch_img_metas):
            # Rows of `rois` that belong to image `i`.
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes: the leading `len(pos_is_gts_)` rows are the
            # positives, drop those flagged as ground truth and keep the rest
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep
            results = InstanceData(bboxes=bboxes[keep_inds.type(torch.bool)])
            results_list.append(results)

        return results_list

    def regress_by_class(self, priors: Tensor, label: Tensor,
                         bbox_pred: Tensor, img_meta: dict) -> Tensor:
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            priors (Tensor): Priors from `rpn_head` or last stage
                `bbox_head`, has shape (num_proposals, 4).
            label (Tensor): Only used when `self.reg_class_agnostic`
                is False, has shape (num_proposals, ).
            bbox_pred (Tensor): Regression prediction of
                current stage `bbox_head`. When `self.reg_class_agnostic`
                is False, it has shape (n, num_classes * 4), otherwise
                it has shape (n, 4).
            img_meta (dict): Image meta info. ``img_shape`` is passed to
                the coder as ``max_shape``.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        reg_dim = self.bbox_coder.encode_size
        if not self.reg_class_agnostic:
            # Column offset of each RoI's class within the flattened
            # (num_classes * reg_dim) prediction, then gather its reg_dim
            # consecutive columns.
            label = label * reg_dim
            inds = torch.stack([label + i for i in range(reg_dim)], 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size()[1] == reg_dim

        max_shape = img_meta['img_shape']
        regressed_bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=max_shape)
        return regressed_bboxes