12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmengine.config import ConfigDict
- from mmengine.model import BaseModule
- from mmengine.structures import InstanceData
- from torch import Tensor
- from torch.nn.modules.utils import _pair
- from mmdet.models.layers import multiclass_nms
- from mmdet.models.losses import accuracy
- from mmdet.models.task_modules.samplers import SamplingResult
- from mmdet.models.utils import empty_instances, multi_apply
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.structures.bbox import get_box_tensor, scale_boxes
- from mmdet.utils import ConfigType, InstanceList, OptMultiConfig
@MODELS.register_module()
class BBoxHead(BaseModule):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively."""

    def __init__(self,
                 with_avg_pool: bool = False,
                 with_cls: bool = True,
                 with_reg: bool = True,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 num_classes: int = 80,
                 bbox_coder: ConfigType = dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 predict_box_type: str = 'hbox',
                 reg_class_agnostic: bool = False,
                 reg_decoded_bbox: bool = False,
                 reg_predictor_cfg: ConfigType = dict(type='Linear'),
                 cls_predictor_cfg: ConfigType = dict(type='Linear'),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        """Initialize the head.

        Args:
            with_avg_pool (bool): Use average pooling over the RoI feature
                map instead of flattening it. Defaults to False.
            with_cls (bool): Build the classification branch. Defaults to
                True.
            with_reg (bool): Build the regression branch. Defaults to True.
            roi_feat_size (int): Spatial size of the RoI feature (expanded
                to a (h, w) pair via ``_pair``). Defaults to 7.
            in_channels (int): Input feature channels. Defaults to 256.
            num_classes (int): Number of foreground classes (background is
                appended internally). Defaults to 80.
            bbox_coder (:obj:`ConfigType`): Config of the box coder built
                through ``TASK_UTILS``.
            predict_box_type (str): Box type used at prediction time,
                e.g. 'hbox'. Defaults to 'hbox'.
            reg_class_agnostic (bool): Predict one box per RoI instead of
                one per class. Defaults to False.
            reg_decoded_bbox (bool): Regression loss is applied on decoded
                (absolute) boxes rather than deltas. Defaults to False.
            reg_predictor_cfg (:obj:`ConfigType`): Config of the regression
                predictor layer.
            cls_predictor_cfg (:obj:`ConfigType`): Config of the
                classification predictor layer.
            loss_cls (:obj:`ConfigType`): Classification loss config.
            loss_bbox (:obj:`ConfigType`): Regression loss config.
            init_cfg (:obj:`OptMultiConfig`): Initialization config; when
                None, Normal init for ``fc_cls``/``fc_reg`` is installed.
        """
        super().__init__(init_cfg=init_cfg)
        # A head with neither branch would be useless.
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.predict_box_type = predict_box_type
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        self.reg_predictor_cfg = reg_predictor_cfg
        self.cls_predictor_cfg = cls_predictor_cfg
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            # No pooling: the flattened RoI feature feeds the fc layers.
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            if self.custom_cls_channels:
                # Some losses (e.g. Seesaw) declare their own channel count.
                cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
            else:
                cls_channels = num_classes + 1
            cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
            cls_predictor_cfg_.update(
                in_features=in_channels, out_features=cls_channels)
            self.fc_cls = MODELS.build(cls_predictor_cfg_)
        if self.with_reg:
            box_dim = self.bbox_coder.encode_size
            # Class-agnostic regression shares one box per RoI.
            out_dim_reg = box_dim if reg_class_agnostic else \
                box_dim * num_classes
            reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
            if isinstance(reg_predictor_cfg_, (dict, ConfigDict)):
                reg_predictor_cfg_.update(
                    in_features=in_channels, out_features=out_dim_reg)
            self.fc_reg = MODELS.build(reg_predictor_cfg_)
        self.debug_imgs = None
        if init_cfg is None:
            # Default init: small Normal std for both predictor layers.
            self.init_cfg = []
            if self.with_cls:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.01, override=dict(name='fc_cls'))
                ]
            if self.with_reg:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.001, override=dict(name='fc_reg'))
                ]
- # TODO: Create a SeasawBBoxHead to simplified logic in BBoxHead
- @property
- def custom_cls_channels(self) -> bool:
- """get custom_cls_channels from loss_cls."""
- return getattr(self.loss_cls, 'custom_cls_channels', False)
- # TODO: Create a SeasawBBoxHead to simplified logic in BBoxHead
- @property
- def custom_activation(self) -> bool:
- """get custom_activation from loss_cls."""
- return getattr(self.loss_cls, 'custom_activation', False)
- # TODO: Create a SeasawBBoxHead to simplified logic in BBoxHead
- @property
- def custom_accuracy(self) -> bool:
- """get custom_accuracy from loss_cls."""
- return getattr(self.loss_cls, 'custom_accuracy', False)
- def forward(self, x: Tuple[Tensor]) -> tuple:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- tuple: A tuple of classification scores and bbox prediction.
- - cls_score (Tensor): Classification scores for all
- scale levels, each is a 4D-tensor, the channels number
- is num_base_priors * num_classes.
- - bbox_pred (Tensor): Box energies / deltas for all
- scale levels, each is a 4D-tensor, the channels number
- is num_base_priors * 4.
- """
- if self.with_avg_pool:
- if x.numel() > 0:
- x = self.avg_pool(x)
- x = x.view(x.size(0), -1)
- else:
- # avg_pool does not support empty tensor,
- # so use torch.mean instead it
- x = torch.mean(x, dim=(-1, -2))
- cls_score = self.fc_cls(x) if self.with_cls else None
- bbox_pred = self.fc_reg(x) if self.with_reg else None
- return cls_score, bbox_pred
    def _get_targets_single(self, pos_priors: Tensor, neg_priors: Tensor,
                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                            cfg: ConfigDict) -> tuple:
        """Calculate the ground truth for proposals in the single image
        according to the sampling results.

        Args:
            pos_priors (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_priors (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains gt_boxes for
                all positive samples, has shape (num_pos, 4),
                the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains gt_labels for
                all positive samples, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.

        Returns:
            Tuple[Tensor]: Ground truth for proposals
            in a single image. Containing the following Tensors:

                - labels(Tensor): Gt_labels for all proposals, has
                  shape (num_proposals,).
                - label_weights(Tensor): Labels_weights for all
                  proposals, has shape (num_proposals,).
                - bbox_targets(Tensor):Regression target for all
                  proposals, has shape (num_proposals, 4), the
                  last dimension 4 represents [tl_x, tl_y, br_x, br_y].
                - bbox_weights(Tensor):Regression weights for all
                  proposals, has shape (num_proposals, 4).
        """
        num_pos = pos_priors.size(0)
        num_neg = neg_priors.size(0)
        # Positives occupy the first num_pos rows, negatives the rest.
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_priors.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        # Target dim: raw gt box dim when regressing decoded boxes,
        # otherwise the coder's encoded delta size.
        reg_dim = pos_gt_bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        label_weights = pos_priors.new_zeros(num_samples)
        bbox_targets = pos_priors.new_zeros(num_samples, reg_dim)
        bbox_weights = pos_priors.new_zeros(num_samples, reg_dim)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            # cfg.pos_weight <= 0 means "use the default weight of 1.0".
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_priors, pos_gt_bboxes)
            else:
                # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
                # is applied directly on the decoded bounding boxes, both
                # the predicted boxes and regression targets should be with
                # absolute coordinate format.
                pos_bbox_targets = get_box_tensor(pos_gt_bboxes)
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            # Negatives contribute to the classification loss only.
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights
- def get_targets(self,
- sampling_results: List[SamplingResult],
- rcnn_train_cfg: ConfigDict,
- concat: bool = True) -> tuple:
- """Calculate the ground truth for all samples in a batch according to
- the sampling_results.
- Almost the same as the implementation in bbox_head, we passed
- additional parameters pos_inds_list and neg_inds_list to
- `_get_targets_single` function.
- Args:
- sampling_results (List[obj:SamplingResult]): Assign results of
- all images in a batch after sampling.
- rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
- concat (bool): Whether to concatenate the results of all
- the images in a single batch.
- Returns:
- Tuple[Tensor]: Ground truth for proposals in a single image.
- Containing the following list of Tensors:
- - labels (list[Tensor],Tensor): Gt_labels for all
- proposals in a batch, each tensor in list has
- shape (num_proposals,) when `concat=False`, otherwise
- just a single tensor has shape (num_all_proposals,).
- - label_weights (list[Tensor]): Labels_weights for
- all proposals in a batch, each tensor in list has
- shape (num_proposals,) when `concat=False`, otherwise
- just a single tensor has shape (num_all_proposals,).
- - bbox_targets (list[Tensor],Tensor): Regression target
- for all proposals in a batch, each tensor in list
- has shape (num_proposals, 4) when `concat=False`,
- otherwise just a single tensor has shape
- (num_all_proposals, 4), the last dimension 4 represents
- [tl_x, tl_y, br_x, br_y].
- - bbox_weights (list[tensor],Tensor): Regression weights for
- all proposals in a batch, each tensor in list has shape
- (num_proposals, 4) when `concat=False`, otherwise just a
- single tensor has shape (num_all_proposals, 4).
- """
- pos_priors_list = [res.pos_priors for res in sampling_results]
- neg_priors_list = [res.neg_priors for res in sampling_results]
- pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
- pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
- labels, label_weights, bbox_targets, bbox_weights = multi_apply(
- self._get_targets_single,
- pos_priors_list,
- neg_priors_list,
- pos_gt_bboxes_list,
- pos_gt_labels_list,
- cfg=rcnn_train_cfg)
- if concat:
- labels = torch.cat(labels, 0)
- label_weights = torch.cat(label_weights, 0)
- bbox_targets = torch.cat(bbox_targets, 0)
- bbox_weights = torch.cat(bbox_weights, 0)
- return labels, label_weights, bbox_targets, bbox_weights
- def loss_and_target(self,
- cls_score: Tensor,
- bbox_pred: Tensor,
- rois: Tensor,
- sampling_results: List[SamplingResult],
- rcnn_train_cfg: ConfigDict,
- concat: bool = True,
- reduction_override: Optional[str] = None) -> dict:
- """Calculate the loss based on the features extracted by the bbox head.
- Args:
- cls_score (Tensor): Classification prediction
- results of all class, has shape
- (batch_size * num_proposals_single_image, num_classes)
- bbox_pred (Tensor): Regression prediction results,
- has shape
- (batch_size * num_proposals_single_image, 4), the last
- dimension 4 represents [tl_x, tl_y, br_x, br_y].
- rois (Tensor): RoIs with the shape
- (batch_size * num_proposals_single_image, 5) where the first
- column indicates batch id of each RoI.
- sampling_results (List[obj:SamplingResult]): Assign results of
- all images in a batch after sampling.
- rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
- concat (bool): Whether to concatenate the results of all
- the images in a single batch. Defaults to True.
- reduction_override (str, optional): The reduction
- method used to override the original reduction
- method of the loss. Options are "none",
- "mean" and "sum". Defaults to None,
- Returns:
- dict: A dictionary of loss and targets components.
- The targets are only used for cascade rcnn.
- """
- cls_reg_targets = self.get_targets(
- sampling_results, rcnn_train_cfg, concat=concat)
- losses = self.loss(
- cls_score,
- bbox_pred,
- rois,
- *cls_reg_targets,
- reduction_override=reduction_override)
- # cls_reg_targets is only for cascade rcnn
- return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)
    def loss(self,
             cls_score: Tensor,
             bbox_pred: Tensor,
             rois: Tensor,
             labels: Tensor,
             label_weights: Tensor,
             bbox_targets: Tensor,
             bbox_weights: Tensor,
             reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the network predictions and targets.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes)
            bbox_pred (Tensor): Regression prediction results,
                has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            rois (Tensor): RoIs with the shape
                (batch_size * num_proposals_single_image, 5) where the first
                column indicates batch id of each RoI.
            labels (Tensor): Gt_labels for all proposals in a batch, has
                shape (batch_size * num_proposals_single_image, ).
            label_weights (Tensor): Labels_weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, ).
            bbox_targets (Tensor): Regression target for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4),
                the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
            bbox_weights (Tensor): Regression weights for all proposals in a
                batch, has shape (batch_size * num_proposals_single_image, 4).
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None,

        Returns:
            dict: A dictionary of loss.
        """
        losses = dict()

        if cls_score is not None:
            # Average over samples that actually carry weight; floor at 1
            # to avoid division by zero when nothing is weighted.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                loss_cls_ = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                # Some losses (e.g. Seesaw) return a dict of named terms.
                if isinstance(loss_cls_, dict):
                    losses.update(loss_cls_)
                else:
                    losses['loss_cls'] = loss_cls_
                if self.custom_activation:
                    acc_ = self.loss_cls.get_accuracy(cls_score, labels)
                    losses.update(acc_)
                else:
                    losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                    bbox_pred = get_box_tensor(bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1)[pos_inds.type(torch.bool)]
                else:
                    # Pick the box prediction of each positive's gt class.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), self.num_classes,
                        -1)[pos_inds.type(torch.bool),
                            labels[pos_inds.type(torch.bool)]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_inds.type(torch.bool)],
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # No positives: keep bbox_pred in the graph with a zero-ish
                # contribution so DDP does not see unused parameters.
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()
        return losses
- def predict_by_feat(self,
- rois: Tuple[Tensor],
- cls_scores: Tuple[Tensor],
- bbox_preds: Tuple[Tensor],
- batch_img_metas: List[dict],
- rcnn_test_cfg: Optional[ConfigDict] = None,
- rescale: bool = False) -> InstanceList:
- """Transform a batch of output features extracted from the head into
- bbox results.
- Args:
- rois (tuple[Tensor]): Tuple of boxes to be transformed.
- Each has shape (num_boxes, 5). last dimension 5 arrange as
- (batch_index, x1, y1, x2, y2).
- cls_scores (tuple[Tensor]): Tuple of box scores, each has shape
- (num_boxes, num_classes + 1).
- bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each
- has shape (num_boxes, num_classes * 4).
- batch_img_metas (list[dict]): List of image information.
- rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN.
- Defaults to None.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- Returns:
- list[:obj:`InstanceData`]: Instance segmentation
- results of each image after the post process.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- """
- assert len(cls_scores) == len(bbox_preds)
- result_list = []
- for img_id in range(len(batch_img_metas)):
- img_meta = batch_img_metas[img_id]
- results = self._predict_by_feat_single(
- roi=rois[img_id],
- cls_score=cls_scores[img_id],
- bbox_pred=bbox_preds[img_id],
- img_meta=img_meta,
- rescale=rescale,
- rcnn_test_cfg=rcnn_test_cfg)
- result_list.append(results)
- return result_list
- def _predict_by_feat_single(
- self,
- roi: Tensor,
- cls_score: Tensor,
- bbox_pred: Tensor,
- img_meta: dict,
- rescale: bool = False,
- rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
- """Transform a single image's features extracted from the head into
- bbox results.
- Args:
- roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
- last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
- cls_score (Tensor): Box scores, has shape
- (num_boxes, num_classes + 1).
- bbox_pred (Tensor): Box energies / deltas.
- has shape (num_boxes, num_classes * 4).
- img_meta (dict): image information.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
- Defaults to None
- Returns:
- :obj:`InstanceData`: Detection results of each image\
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- """
- results = InstanceData()
- if roi.shape[0] == 0:
- return empty_instances([img_meta],
- roi.device,
- task_type='bbox',
- instance_results=[results],
- box_type=self.predict_box_type,
- use_box_type=False,
- num_classes=self.num_classes,
- score_per_cls=rcnn_test_cfg is None)[0]
- # some loss (Seesaw loss..) may have custom activation
- if self.custom_cls_channels:
- scores = self.loss_cls.get_activation(cls_score)
- else:
- scores = F.softmax(
- cls_score, dim=-1) if cls_score is not None else None
- img_shape = img_meta['img_shape']
- num_rois = roi.size(0)
- # bbox_pred would be None in some detector when with_reg is False,
- # e.g. Grid R-CNN.
- if bbox_pred is not None:
- num_classes = 1 if self.reg_class_agnostic else self.num_classes
- roi = roi.repeat_interleave(num_classes, dim=0)
- bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
- bboxes = self.bbox_coder.decode(
- roi[..., 1:], bbox_pred, max_shape=img_shape)
- else:
- bboxes = roi[:, 1:].clone()
- if img_shape is not None and bboxes.size(-1) == 4:
- bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
- bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
- if rescale and bboxes.size(0) > 0:
- assert img_meta.get('scale_factor') is not None
- scale_factor = [1 / s for s in img_meta['scale_factor']]
- bboxes = scale_boxes(bboxes, scale_factor)
- # Get the inside tensor when `bboxes` is a box type
- bboxes = get_box_tensor(bboxes)
- box_dim = bboxes.size(-1)
- bboxes = bboxes.view(num_rois, -1)
- if rcnn_test_cfg is None:
- # This means that it is aug test.
- # It needs to return the raw results without nms.
- results.bboxes = bboxes
- results.scores = scores
- else:
- det_bboxes, det_labels = multiclass_nms(
- bboxes,
- scores,
- rcnn_test_cfg.score_thr,
- rcnn_test_cfg.nms,
- rcnn_test_cfg.max_per_img,
- box_dim=box_dim)
- results.bboxes = det_bboxes[:, :-1]
- results.scores = det_bboxes[:, -1]
- results.labels = det_labels
- return results
    def refine_bboxes(self, sampling_results: Union[List[SamplingResult],
                                                    InstanceList],
                      bbox_results: dict,
                      batch_img_metas: List[dict]) -> InstanceList:
        """Refine bboxes during training.

        Args:
            sampling_results (List[:obj:`SamplingResult`] or
                List[:obj:`InstanceData`]): Sampling results.
                :obj:`SamplingResult` is the real sampling results
                calculate from bbox_head, while :obj:`InstanceData` is
                fake sampling results, e.g., in Sparse R-CNN or QueryInst, etc.
            bbox_results (dict): Usually is a dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `rois` (Tensor): RoIs with the shape (n, 5) where the first
                  column indicates batch id of each RoI.
                - `bbox_targets` (tuple): Ground truth for proposals in a
                  single image. Containing the following list of Tensors:
                  (labels, label_weights, bbox_targets, bbox_weights)
            batch_img_metas (List[dict]): List of image information.

        Returns:
            list[:obj:`InstanceData`]: Refined bboxes of each image.

        Example:
            >>> # xdoctest: +REQUIRES(module:kwarray)
            >>> import numpy as np
            >>> from mmdet.models.task_modules.samplers.
            ... sampling_result import random_boxes
            >>> from mmdet.models.task_modules.samplers import SamplingResult
            >>> self = BBoxHead(reg_class_agnostic=True)
            >>> n_roi = 2
            >>> n_img = 4
            >>> scale = 512
            >>> rng = np.random.RandomState(0)
            >>> batch_img_metas = [{'img_shape': (scale, scale)}
            ...                    for _ in range(n_img)]
            >>> sampling_results = [SamplingResult.random(rng=10)
            ...                     for _ in range(n_img)]
            >>> # Create rois in the expected format
            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
            >>> img_ids = torch.randint(0, n_img, (n_roi,))
            >>> img_ids = img_ids.float()
            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
            >>> # Create other args
            >>> labels = torch.randint(0, 81, (scale,)).long()
            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
            >>> cls_score = torch.randn((scale, 81))
            >>> # For each image, pretend random positive boxes are gts
            >>> bbox_targets = (labels, None, None, None)
            >>> bbox_results = dict(rois=rois, bbox_pred=bbox_preds,
            ...                     cls_score=cls_score,
            ...                     bbox_targets=bbox_targets)
            >>> bboxes_list = self.refine_bboxes(sampling_results,
            ...                                  bbox_results,
            ...                                  batch_img_metas)
            >>> print(bboxes_list)
        """
        pos_is_gts = [res.pos_is_gt for res in sampling_results]
        # bbox_targets is a tuple
        labels = bbox_results['bbox_targets'][0]
        cls_scores = bbox_results['cls_score']
        rois = bbox_results['rois']
        bbox_preds = bbox_results['bbox_pred']

        if self.custom_activation:
            # TODO: Create a SeasawBBoxHead to simplified logic in BBoxHead
            cls_scores = self.loss_cls.get_activation(cls_scores)
        if cls_scores.numel() == 0:
            return None
        if cls_scores.shape[-1] == self.num_classes + 1:
            # remove background class
            cls_scores = cls_scores[:, :-1]
        elif cls_scores.shape[-1] != self.num_classes:
            raise ValueError('The last dim of `cls_scores` should equal to '
                             '`num_classes` or `num_classes + 1`,'
                             f'but got {cls_scores.shape[-1]}.')
        # Replace background labels with the currently most confident
        # foreground class so those proposals can still be regressed.
        labels = torch.where(labels == self.num_classes, cls_scores.argmax(1),
                             labels)

        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(batch_img_metas)

        results_list = []
        for i in range(len(batch_img_metas)):
            # Select the RoIs belonging to image i via the batch-id column.
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            img_meta_ = batch_img_metas[i]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            results = InstanceData(bboxes=bboxes[keep_inds.type(torch.bool)])
            results_list.append(results)

        return results_list
- def regress_by_class(self, priors: Tensor, label: Tensor,
- bbox_pred: Tensor, img_meta: dict) -> Tensor:
- """Regress the bbox for the predicted class. Used in Cascade R-CNN.
- Args:
- priors (Tensor): Priors from `rpn_head` or last stage
- `bbox_head`, has shape (num_proposals, 4).
- label (Tensor): Only used when `self.reg_class_agnostic`
- is False, has shape (num_proposals, ).
- bbox_pred (Tensor): Regression prediction of
- current stage `bbox_head`. When `self.reg_class_agnostic`
- is False, it has shape (n, num_classes * 4), otherwise
- it has shape (n, 4).
- img_meta (dict): Image meta info.
- Returns:
- Tensor: Regressed bboxes, the same shape as input rois.
- """
- reg_dim = self.bbox_coder.encode_size
- if not self.reg_class_agnostic:
- label = label * reg_dim
- inds = torch.stack([label + i for i in range(reg_dim)], 1)
- bbox_pred = torch.gather(bbox_pred, 1, inds)
- assert bbox_pred.size()[1] == reg_dim
- max_shape = img_meta['img_shape']
- regressed_bboxes = self.bbox_coder.decode(
- priors, bbox_pred, max_shape=max_shape)
- return regressed_bboxes
|