# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import numpy as np
import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList)
from ..layers import multiclass_nms
from ..utils import levels_to_images, multi_apply
from . import ATSSHead

EPS = 1e-12
try:
    import sklearn.mixture as skm
except ImportError:
    skm = None


@MODELS.register_module()
class PAAHead(ATSSHead):
    """Head of PAA: Probabilistic Anchor Assignment with IoU Prediction for
    Object Detection.

    Code is modified from the `official github repo
    <https://github.com/kkhoot/PAA/blob/master/paa_core
    /modeling/rpn/paa/loss.py>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2007.08103>`_ .

    Args:
        topk (int): Select topk samples with smallest loss in
            each level.
        score_voting (bool): Whether to use score voting in post-process.
        covariance_type (str): String describing the type of covariance
            parameters to be used in
            :class:`sklearn.mixture.GaussianMixture`. It must be one of:

            - 'full': each component has its own general covariance matrix
            - 'tied': all components share the same general covariance matrix
            - 'diag': each component has its own diagonal covariance matrix
            - 'spherical': each component has its own single variance

            Default: 'diag'. From 'full' to 'spherical', the GMM fitting
            process is faster, but the performance may be affected. For most
            cases, 'diag' is a good choice.
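
    Example::

        >>> # A minimal sketch of the 1-D, two-component GMM this head fits
        >>> # over candidate losses (illustrative values; requires
        >>> # scikit-learn):
        >>> import numpy as np
        >>> import sklearn.mixture as skm
        >>> losses = np.array([0.1, 0.2, 0.3, 2.0, 2.1]).reshape(-1, 1)
        >>> gmm = skm.GaussianMixture(
        ...     2,
        ...     weights_init=np.array([0.5, 0.5]),
        ...     means_init=np.array([[losses.min()], [losses.max()]]),
        ...     precisions_init=np.ones((2, 1)),  # 'diag' shape: (2, 1)
        ...     covariance_type='diag').fit(losses)
        >>> # component 0 is initialized at the low-loss mode, component 1
        >>> # at the high-loss mode
        >>> assignment = gmm.predict(losses)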
- """
- def __init__(self,
- *args,
- topk: int = 9,
- score_voting: bool = True,
- covariance_type: str = 'diag',
- **kwargs):
- # topk used in paa reassign process
- self.topk = topk
- self.with_score_voting = score_voting
- self.covariance_type = covariance_type
- super().__init__(*args, **kwargs)

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            iou_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, num_anchors * 4, H, W).
            iou_preds (list[Tensor]): IoU predictions for each scale
                level, each with shape (N, num_anchors * 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
        )
        (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
         pos_gt_index) = cls_reg_targets
        cls_scores = levels_to_images(cls_scores)
        cls_scores = [
            item.reshape(-1, self.cls_out_channels) for item in cls_scores
        ]
        bbox_preds = levels_to_images(bbox_preds)
        bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
        iou_preds = levels_to_images(iou_preds)
        iou_preds = [item.reshape(-1, 1) for item in iou_preds]
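        # `levels_to_images` regroups the per-level maps (N, C, H, W) into
        # per-image tensors, so each list item above is now
        # (num_anchors_per_image, C) for a single image.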
        pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
                                       cls_scores, bbox_preds, labels,
                                       labels_weight, bboxes_target,
                                       bboxes_weight, pos_inds)

        with torch.no_grad():
            reassign_labels, reassign_label_weight, \
                reassign_bbox_weights, num_pos = multi_apply(
                    self.paa_reassign,
                    pos_losses_list,
                    labels,
                    labels_weight,
                    bboxes_weight,
                    pos_inds,
                    pos_gt_index,
                    anchor_list)
            num_pos = sum(num_pos)
        # convert the lists of per-image tensors into flattened tensors
        cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
        bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
        iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
        labels = torch.cat(reassign_labels, 0).view(-1)
        flatten_anchors = torch.cat(
            [torch.cat(item, 0) for item in anchor_list])
        labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
        bboxes_target = torch.cat(bboxes_target,
                                  0).view(-1, bboxes_target[0].size(-1))

        pos_inds_flatten = ((labels >= 0)
                            &
                            (labels < self.num_classes)).nonzero().reshape(-1)
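        # Foreground labels lie in [0, num_classes); `num_classes` itself is
        # the background index, so this mask picks out positive samples.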
        losses_cls = self.loss_cls(
            cls_scores,
            labels,
            labels_weight,
            avg_factor=max(num_pos, len(batch_img_metas)))  # avoid num_pos=0
        if num_pos:
            pos_bbox_pred = self.bbox_coder.decode(
                flatten_anchors[pos_inds_flatten],
                bbox_preds[pos_inds_flatten])
            pos_bbox_target = bboxes_target[pos_inds_flatten]
            iou_target = bbox_overlaps(
                pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
            losses_iou = self.loss_centerness(
                iou_preds[pos_inds_flatten],
                iou_target.unsqueeze(-1),
                avg_factor=num_pos)
            losses_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_target,
                iou_target.clamp(min=EPS),
                avg_factor=iou_target.sum())
        else:
            losses_iou = iou_preds.sum() * 0
            losses_bbox = bbox_preds.sum() * 0

        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)

    def get_pos_loss(self, anchors: List[Tensor], cls_score: Tensor,
                     bbox_pred: Tensor, label: Tensor, label_weight: Tensor,
                     bbox_target: Tensor, bbox_weight: Tensor,
                     pos_inds: Tensor) -> Tensor:
        """Calculate the loss of all potential positive samples obtained from
        the first match process.

        Args:
            anchors (list[Tensor]): Anchors of each scale.
            cls_score (Tensor): Box scores of a single image with shape
                (num_anchors, num_classes).
            bbox_pred (Tensor): Box energies / deltas of a single image
                with shape (num_anchors, 4).
            label (Tensor): Classification target of each anchor with
                shape (num_anchors,).
            label_weight (Tensor): Classification loss weight of each
                anchor with shape (num_anchors,).
            bbox_target (Tensor): Regression target of each anchor with
                shape (num_anchors, 4).
            bbox_weight (Tensor): Bbox weight of each anchor with shape
                (num_anchors, 4).
            pos_inds (Tensor): Indices of all positive samples from the
                first assign process.

        Returns:
            Tensor: Losses of all positive samples in a single image.
        """
        if not len(pos_inds):
            return cls_score.new([]),
        anchors_all_level = torch.cat(anchors, 0)
        pos_scores = cls_score[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]
        pos_label = label[pos_inds]
        pos_label_weight = label_weight[pos_inds]
        pos_bbox_target = bbox_target[pos_inds]
        pos_bbox_weight = bbox_weight[pos_inds]
        pos_anchors = anchors_all_level[pos_inds]
        pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)

        # use reduction_override='none' to keep the per-sample loss dimension
        loss_cls = self.loss_cls(
            pos_scores,
            pos_label,
            pos_label_weight,
            avg_factor=1.0,
            reduction_override='none')
        loss_bbox = self.loss_bbox(
            pos_bbox_pred,
            pos_bbox_target,
            pos_bbox_weight,
            avg_factor=1.0,  # keep the same loss weight as before reassign
            reduction_override='none')

        loss_cls = loss_cls.sum(-1)
        pos_loss = loss_bbox + loss_cls
        return pos_loss,

    def paa_reassign(self, pos_losses: Tensor, label: Tensor,
                     label_weight: Tensor, bbox_weight: Tensor,
                     pos_inds: Tensor, pos_gt_inds: Tensor,
                     anchors: List[Tensor]) -> tuple:
        """Fit the losses to a GMM and re-separate positive, ignore, and
        negative samples with the fitted model.

        Args:
            pos_losses (Tensor): Losses of all positive samples in a
                single image.
            label (Tensor): Classification target of each anchor with
                shape (num_anchors,).
            label_weight (Tensor): Classification loss weight of each
                anchor with shape (num_anchors,).
            bbox_weight (Tensor): Bbox weight of each anchor with shape
                (num_anchors, 4).
            pos_inds (Tensor): Indices of all positive samples from the
                first assign process.
            pos_gt_inds (Tensor): Gt_index of all positive samples from
                the first assign process.
            anchors (list[Tensor]): Anchors of each scale.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - label (Tensor): Classification target of each anchor after
              paa assign, with shape (num_anchors,).
            - label_weight (Tensor): Classification loss weight of each
              anchor after paa assign, with shape (num_anchors,).
            - bbox_weight (Tensor): Bbox weight of each anchor with shape
              (num_anchors, 4).
            - num_pos (int): The number of positive samples after paa
              assign.
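
        Example::

            >>> # Sketch of the per-level index intervals computed below to
            >>> # split flattened anchor indices by pyramid level (anchor
            >>> # counts are illustrative):
            >>> import numpy as np
            >>> num_anchors_each_level = [0, 100, 25, 7]
            >>> inds_level_interval = np.cumsum(num_anchors_each_level)
            >>> # anchors of level i have flat indices in
            >>> # [inds_level_interval[i], inds_level_interval[i + 1])
            >>> inds_level_interval
            array([  0, 100, 125, 132])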
- """
- if not len(pos_inds):
- return label, label_weight, bbox_weight, 0
- label = label.clone()
- label_weight = label_weight.clone()
- bbox_weight = bbox_weight.clone()
- num_gt = pos_gt_inds.max() + 1
- num_level = len(anchors)
- num_anchors_each_level = [item.size(0) for item in anchors]
- num_anchors_each_level.insert(0, 0)
- inds_level_interval = np.cumsum(num_anchors_each_level)
- pos_level_mask = []
- for i in range(num_level):
- mask = (pos_inds >= inds_level_interval[i]) & (
- pos_inds < inds_level_interval[i + 1])
- pos_level_mask.append(mask)
- pos_inds_after_paa = [label.new_tensor([])]
- ignore_inds_after_paa = [label.new_tensor([])]
        for gt_ind in range(num_gt):
            pos_inds_gmm = []
            pos_loss_gmm = []
            gt_mask = pos_gt_inds == gt_ind
            for level in range(num_level):
                level_mask = pos_level_mask[level]
                level_gt_mask = level_mask & gt_mask
                value, topk_inds = pos_losses[level_gt_mask].topk(
                    min(level_gt_mask.sum(), self.topk), largest=False)
                pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
                pos_loss_gmm.append(value)
            pos_inds_gmm = torch.cat(pos_inds_gmm)
            pos_loss_gmm = torch.cat(pos_loss_gmm)
            # a GMM needs at least two samples to fit
            if len(pos_inds_gmm) < 2:
                continue
            device = pos_inds_gmm.device
            pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
            pos_inds_gmm = pos_inds_gmm[sort_inds]
            pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
            min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
            means_init = np.array([min_loss, max_loss]).reshape(2, 1)
            weights_init = np.array([0.5, 0.5])
            precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1)  # full
            if self.covariance_type == 'spherical':
                precisions_init = precisions_init.reshape(2)
            elif self.covariance_type == 'diag':
                precisions_init = precisions_init.reshape(2, 1)
            elif self.covariance_type == 'tied':
                precisions_init = np.array([[1.0]])
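            # The shapes above follow sklearn's `precisions_init` convention
            # for 2 components and 1 feature: (2, 1, 1) for 'full',
            # (2,) for 'spherical', (2, 1) for 'diag', and (1, 1) for 'tied'
            # (one shared 1x1 precision matrix).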
            if skm is None:
                raise ImportError('Please run "pip install scikit-learn" '
                                  'to install scikit-learn first.')
            gmm = skm.GaussianMixture(
                2,
                weights_init=weights_init,
                means_init=means_init,
                precisions_init=precisions_init,
                covariance_type=self.covariance_type)
            gmm.fit(pos_loss_gmm)
            gmm_assignment = gmm.predict(pos_loss_gmm)
            scores = gmm.score_samples(pos_loss_gmm)
            gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
            scores = torch.from_numpy(scores).to(device)

            pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
                gmm_assignment, scores, pos_inds_gmm)
            pos_inds_after_paa.append(pos_inds_temp)
            ignore_inds_after_paa.append(ignore_inds_temp)

        pos_inds_after_paa = torch.cat(pos_inds_after_paa)
        ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
        reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
        reassign_ids = pos_inds[reassign_mask]
        label[reassign_ids] = self.num_classes
        label_weight[ignore_inds_after_paa] = 0
        bbox_weight[reassign_ids] = 0
        num_pos = len(pos_inds_after_paa)
        return label, label_weight, bbox_weight, num_pos

    def gmm_separation_scheme(self, gmm_assignment: Tensor, scores: Tensor,
                              pos_inds_gmm: Tensor) -> Tuple[Tensor, Tensor]:
        """A general separation scheme for the GMM model.

        It separates a GMM distribution of candidate samples into three
        parts, 0, 1, and uncertain areas; other separation schemes can be
        implemented by overriding this function.

        Args:
            gmm_assignment (Tensor): The prediction of GMM which is of shape
                (num_samples,). The 0/1 value indicates the distribution
                that each sample comes from.
            scores (Tensor): The probability of each sample coming from the
                fit GMM distribution. The tensor is of shape (num_samples,).
            pos_inds_gmm (Tensor): All the indices of samples which are used
                to fit the GMM model. The tensor is of shape (num_samples,).

        Returns:
            tuple[Tensor, Tensor]: The indices of positive and ignored
            samples.

            - pos_inds_temp (Tensor): Indices of positive samples.
            - ignore_inds_temp (Tensor): Indices of ignored samples.
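
        Example::

            >>> # Illustrative run of the scheme implemented below: samples
            >>> # assigned to component 0 (low loss), up to and including the
            >>> # highest-scoring one, become positives (values are made up):
            >>> import torch
            >>> gmm_assignment = torch.tensor([0, 0, 0, 1, 1])
            >>> scores = torch.tensor([-0.1, -0.5, -2.0, -3.0, -4.0])
            >>> pos_inds_gmm = torch.tensor([11, 42, 7, 99, 3])
            >>> fgs = gmm_assignment == 0
            >>> _, pos_thr_ind = scores[fgs].topk(1)
            >>> pos_inds_gmm[fgs][:pos_thr_ind + 1]
            tensor([11])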
- """
- # The implementation is (c) in Fig.3 in origin paper instead of (b).
- # You can refer to issues such as
- # https://github.com/kkhoot/PAA/issues/8 and
- # https://github.com/kkhoot/PAA/issues/9.
- fgs = gmm_assignment == 0
- pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
- ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
- if fgs.nonzero().numel():
- _, pos_thr_ind = scores[fgs].topk(1)
- pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
- ignore_inds_temp = pos_inds_gmm.new_tensor([])
- return pos_inds_temp, ignore_inds_temp

    def get_targets(self,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True) -> tuple:
        """Get targets for PAA head.

        This method is almost the same as `AnchorHead.get_targets()`. We
        directly return the results from `_get_targets_single` instead of
        mapping them to levels with the `images_to_levels` function.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors,).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - labels (list[Tensor]): Labels of all anchors, each with
              shape (num_anchors,).
            - label_weights (list[Tensor]): Label weights of all anchors,
              each with shape (num_anchors,).
            - bbox_targets (list[Tensor]): BBox targets of all anchors,
              each with shape (num_anchors, 4).
            - bbox_weights (list[Tensor]): BBox weights of all anchors,
              each with shape (num_anchors, 4).
            - pos_inds (list[Tensor]): Indices of positive samples in
              all anchors.
            - gt_inds (list[Tensor]): Gt_index of positive samples in
              all anchors.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore,
            unmap_outputs=unmap_outputs)

        (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
         valid_neg_inds, sampling_result) = results

        # Due to the valid flags of anchors, we have to recalculate the real
        # pos_inds in the original anchor set.
        pos_inds = []
        for i, single_labels in enumerate(labels):
            pos_mask = (0 <= single_labels) & (
                single_labels < self.num_classes)
            pos_inds.append(pos_mask.nonzero().view(-1))

        gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                gt_inds)

    def _get_targets_single(self,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression and classification targets for anchors in a
        single image.

        This method is the same as `AnchorHead._get_targets_single()`.
        """
        assert unmap_outputs, 'We must map outputs back to the original ' \
                              'set of anchors in PAAHead'
        return super(ATSSHead, self)._get_targets_single(
            flat_anchors,
            valid_flags,
            gt_instances,
            img_meta,
            gt_instances_ignore,
            unmap_outputs=True)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: OptConfigType = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        This method is the same as `BaseDenseHead.get_results()`.
        """
        assert with_nms, 'PAA only supports "with_nms=True" now, which ' \
                         'means PAAHead does not support ' \
                         'test-time augmentation'
        return super().predict_by_feat(
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            score_factors=score_factors,
            batch_img_metas=batch_img_metas,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms)

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: OptConfigType = None,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factors from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (:obj:`ConfigDict` or dict, optional): Test / postprocessing
                configuration. If None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before returning boxes.
                Default: True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances,).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances,).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_score_factors = []
        for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
                enumerate(zip(cls_score_list, bbox_pred_list,
                              score_factor_list, mlvl_priors)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
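            # Rank candidates by sqrt(cls_score * iou_pred), the same
            # geometric-mean score used later for NMS, so pre-NMS top-k
            # filtering and the final ranking stay consistent.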
            if 0 < nms_pre < scores.shape[0]:
                max_scores, _ = (scores *
                                 score_factor[:, None]).sqrt().max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                priors = priors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                score_factor = score_factor[topk_inds]

            bboxes = self.bbox_coder.decode(
                priors, bbox_pred, max_shape=img_shape)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_score_factors.append(score_factor)

        results = InstanceData()
        results.bboxes = torch.cat(mlvl_bboxes)
        results.scores = torch.cat(mlvl_scores)
        results.score_factors = torch.cat(mlvl_score_factors)

        return self._bbox_post_process(results, cfg, rescale, with_nms,
                                       img_meta)

    def _bbox_post_process(self,
                           results: InstanceData,
                           cfg: ConfigType,
                           rescale: bool = False,
                           with_nms: bool = True,
                           img_meta: Optional[dict] = None) -> InstanceData:
        """Bbox post-processing method.

        The boxes are rescaled to the original image scale and NMS is
        applied. Usually `with_nms=False` is used for aug test.

        Args:
            results (:obj:`InstanceData`): Detection instance results,
                each item has shape (num_bboxes, ).
            cfg (:obj:`ConfigDict` or dict): Test / postprocessing
                configuration. If None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before returning boxes.
                Default: True.
            img_meta (dict, optional): Image meta info. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances,).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances,).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        if rescale:
            results.bboxes /= results.bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))
        # Add a dummy background class to the backend when using sigmoid.
        # Remind that we set FG labels to [0, num_class - 1] since mmdet v2.0:
        # BG cat_id = num_class.
        padding = results.scores.new_zeros(results.scores.shape[0], 1)
        mlvl_scores = torch.cat([results.scores, padding], dim=1)

        mlvl_nms_scores = (mlvl_scores * results.score_factors[:, None]).sqrt()
        det_bboxes, det_labels = multiclass_nms(
            results.bboxes,
            mlvl_nms_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=None)
        if self.with_score_voting and len(det_bboxes) > 0:
            det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
                                                       results.bboxes,
                                                       mlvl_nms_scores,
                                                       cfg.score_thr)
        nms_results = InstanceData()
        nms_results.bboxes = det_bboxes[:, :-1]
        nms_results.scores = det_bboxes[:, -1]
        nms_results.labels = det_labels
        return nms_results
- def score_voting(self, det_bboxes: Tensor, det_labels: Tensor,
- mlvl_bboxes: Tensor, mlvl_nms_scores: Tensor,
- score_thr: float) -> Tuple[Tensor, Tensor]:
- """Implementation of score voting method works on each remaining boxes
- after NMS procedure.
- Args:
- det_bboxes (Tensor): Remaining boxes after NMS procedure,
- with shape (k, 5), each dimension means
- (x1, y1, x2, y2, score).
- det_labels (Tensor): The label of remaining boxes, with shape
- (k, 1),Labels are 0-based.
- mlvl_bboxes (Tensor): All boxes before the NMS procedure,
- with shape (num_anchors,4).
- mlvl_nms_scores (Tensor): The scores of all boxes which is used
- in the NMS procedure, with shape (num_anchors, num_class)
- score_thr (float): The score threshold of bboxes.
- Returns:
- tuple: Usually returns a tuple containing voting results.
- - det_bboxes_voted (Tensor): Remaining boxes after
- score voting procedure, with shape (k, 5), each
- dimension means (x1, y1, x2, y2, score).
- - det_labels_voted (Tensor): Label of remaining bboxes
- after voting, with shape (num_anchors,).
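
        Example::

            >>> # Sketch of the voting weight computed below: an IoU Gaussian
            >>> # (variance 0.025) times the candidate score (made-up values):
            >>> import torch
            >>> pos_ious = torch.tensor([0.95, 0.60])
            >>> pos_scores = torch.tensor([0.90, 0.80])
            >>> pis = torch.exp(-(1 - pos_ious)**2 / 0.025) * pos_scores
            >>> # near-identical boxes (IoU close to 1) dominate the
            >>> # weighted average of candidate coordinates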
- """
- candidate_mask = mlvl_nms_scores > score_thr
- candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False)
- candidate_inds = candidate_mask_nonzeros[:, 0]
- candidate_labels = candidate_mask_nonzeros[:, 1]
- candidate_bboxes = mlvl_bboxes[candidate_inds]
- candidate_scores = mlvl_nms_scores[candidate_mask]
- det_bboxes_voted = []
- det_labels_voted = []
- for cls in range(self.cls_out_channels):
- candidate_cls_mask = candidate_labels == cls
- if not candidate_cls_mask.any():
- continue
- candidate_cls_scores = candidate_scores[candidate_cls_mask]
- candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
- det_cls_mask = det_labels == cls
- det_cls_bboxes = det_bboxes[det_cls_mask].view(
- -1, det_bboxes.size(-1))
- det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
- candidate_cls_bboxes)
- for det_ind in range(len(det_cls_bboxes)):
- single_det_ious = det_candidate_ious[det_ind]
- pos_ious_mask = single_det_ious > 0.01
- pos_ious = single_det_ious[pos_ious_mask]
- pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
- pos_scores = candidate_cls_scores[pos_ious_mask]
- pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
- pos_scores)[:, None]
- voted_box = torch.sum(
- pis * pos_bboxes, dim=0) / torch.sum(
- pis, dim=0)
- voted_score = det_cls_bboxes[det_ind][-1:][None, :]
- det_bboxes_voted.append(
- torch.cat((voted_box[None, :], voted_score), dim=1))
- det_labels_voted.append(cls)
- det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
- det_labels_voted = det_labels.new_tensor(det_labels_voted)
- return det_bboxes_voted, det_labels_voted