123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import Linear
- from mmcv.cnn.bricks.transformer import FFN
- from mmengine.model import BaseModule
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.structures import SampleList
- from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
- from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
- OptMultiConfig, reduce_mean)
- from ..utils import multi_apply
@MODELS.register_module()
class DETRHead(BaseModule):
    r"""Head of DETR. DETR:End-to-End Object Detection with Transformers.

    More details can be found in the `paper
    <https://arxiv.org/pdf/2005.12872>`_ .

    Args:
        num_classes (int): Number of categories excluding the background.
        embed_dims (int): The dims of Transformer embedding.
            Defaults to 256.
        num_reg_fcs (int): Number of fully-connected layers used in `FFN`,
            which is then used for the regression head. Defaults to 2.
        sync_cls_avg_factor (bool): Whether to sync the `avg_factor` of
            all ranks. Defaults to `False`.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`. The passed config is not
            modified.
        loss_bbox (:obj:`ConfigDict` or dict): Config of the regression bbox
            loss. Defaults to `L1Loss`.
        loss_iou (:obj:`ConfigDict` or dict): Config of the regression iou
            loss. Defaults to `GIoULoss`.
        train_cfg (:obj:`ConfigDict` or dict): Training config of transformer
            head. When set, it must contain an ``assigner`` and must not
            contain a ``sampler``.
        test_cfg (:obj:`ConfigDict` or dict): Testing config of transformer
            head.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    _version = 2

    def __init__(
            self,
            num_classes: int,
            embed_dims: int = 256,
            num_reg_fcs: int = 2,
            sync_cls_avg_factor: bool = False,
            loss_cls: ConfigType = dict(
                type='CrossEntropyLoss',
                bg_cls_weight=0.1,
                use_sigmoid=False,
                loss_weight=1.0,
                class_weight=1.0),
            loss_bbox: ConfigType = dict(type='L1Loss', loss_weight=5.0),
            loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0),
            train_cfg: ConfigType = dict(
                assigner=dict(
                    type='HungarianAssigner',
                    match_costs=[
                        dict(type='ClassificationCost', weight=1.),
                        dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                        dict(type='IoUCost', iou_mode='giou', weight=2.0)
                    ])),
            test_cfg: ConfigType = dict(max_per_img=100),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        class_weight = loss_cls.get('class_weight', None)
        # The scalar -> per-class expansion is only done for DETRHead
        # itself, not for subclasses.
        if class_weight is not None and (self.__class__ is DETRHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official DETR repo, bg_cls_weight means
            # relative classification weight of the no-object class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(num_classes + 1) * class_weight
            # set background class as the last index
            class_weight[num_classes] = bg_cls_weight
            # Edit a shallow copy: `loss_cls` may be the shared default
            # argument or a dict owned by the caller. Mutating it in place
            # would make a second construction see a Tensor `class_weight`
            # and fail the float assertion above.
            loss_cls = loss_cls.copy()
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight

        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided ' \
                                            'when train_cfg is set.'
            assigner = train_cfg['assigner']
            self.assigner = TASK_UTILS.build(assigner)
            if train_cfg.get('sampler', None) is not None:
                raise RuntimeError('DETR do not build sampler.')
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_iou = MODELS.build(loss_iou)

        # Softmax classification carries one extra (last) channel for the
        # background / no-object class; sigmoid classification does not.
        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the transformer head."""
        # cls branch
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        # reg branch
        self.activate = nn.ReLU()
        self.reg_ffn = FFN(
            self.embed_dims,
            self.embed_dims,
            self.num_reg_fcs,
            dict(type='ReLU', inplace=True),
            dropout=0.0,
            add_residual=False)
        # NOTE the activations of reg_branch here is the same as
        # those in transformer, but they are actually different
        # in DAB-DETR (prelu in transformer and relu in reg_branch)
        self.fc_reg = Linear(self.embed_dims, 4)

    def forward(self, hidden_states: Tensor) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from transformer decoder. If
                `return_intermediate_dec` in detr.py is True output has shape
                (num_decoder_layers, bs, num_queries, dim), else has shape
                (1, bs, num_queries, dim) which only contains the last layer
                outputs.

        Returns:
            tuple[Tensor]: results of head containing the following tensor.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). cls_out_channels is ``num_classes + 1``
              (background last) unless the classification loss uses
              sigmoid, in which case it is ``num_classes``.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has
              shape (num_decoder_layers, bs, num_queries, 4).
        """
        layers_cls_scores = self.fc_cls(hidden_states)
        reg_feats = self.activate(self.reg_ffn(hidden_states))
        layers_bbox_preds = self.fc_reg(reg_feats).sigmoid()
        return layers_cls_scores, layers_bbox_preds

    def loss(self, hidden_states: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Losses are computed for every decoder layer. Those of the last
        layer are stored under ``loss_cls``, ``loss_bbox`` and ``loss_iou``;
        those of the earlier layers under the same names prefixed with
        ``d{i}.``, where ``i`` is the layer index starting from 0.

        Args:
            all_layers_cls_scores (Tensor): Classification outputs
                of each decoder layers. Each is a 4D-tensor, has shape
                (num_decoder_layers, bs, num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Sigmoid regression
                outputs of each decoder layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                (num_decoder_layers, bs, num_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. Must be None; any other value
                triggers an assertion error. Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert batch_gt_instances_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            'for batch_gt_instances_ignore setting to None.'

        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_by_feat_single,
            all_layers_cls_scores,
            all_layers_bbox_preds,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        aux_losses = zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1])
        for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in \
                enumerate(aux_losses):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
        return loss_dict

    def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                            batch_gt_instances: InstanceList,
                            batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer of a single
        feature level.

        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images, has shape (bs, num_queries, cls_out_channels).
            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
                for all images, with normalized coordinate (cx, cy, w, h) and
                shape (bs, num_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. ``img_shape`` is read from
                each entry and unpacked as (img_h, img_w).

        Returns:
            Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and
            `loss_iou`.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           batch_gt_instances, batch_img_metas)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        # guard against a zero (or tiny) normalizer
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used for rescale bboxes: one (w, h, w, h) row
        # per query of every image
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w, = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regress the relative position of boxes (cxcywh) in the image,
        # thus the learning target is normalized by the image size. So here
        # we need to re-scale them for calculating IoU loss
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss, computed on the normalized cxcywh values
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou

    def get_targets(self, cls_scores_list: List[Tensor],
                    bbox_preds_list: List[Tensor],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict]) -> tuple:
        """Compute regression and classification targets for a batch image.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image, has shape [num_queries,
                cls_out_channels].
            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
                decoder layer for each image, with normalized coordinate
                (cx, cy, w, h) and shape [num_queries, 4].
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            tuple: a tuple containing the following targets.

            - labels_list (list[Tensor]): Labels for all images.
            - label_weights_list (list[Tensor]): Label weights for all images.
            - bbox_targets_list (list[Tensor]): BBox targets for all images.
            - bbox_weights_list (list[Tensor]): BBox weights for all images.
            - num_total_pos (int): Number of positive samples in all images.
            - num_total_neg (int): Number of negative samples in all images.
        """
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         pos_inds_list,
         neg_inds_list) = multi_apply(self._get_targets_single,
                                      cls_scores_list, bbox_preds_list,
                                      batch_gt_instances, batch_img_metas)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> tuple:
        """Compute regression and classification targets for one image.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_queries, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
                for one image, with normalized coordinate (cx, cy, w, h) and
                shape [num_queries, 4].
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for one image. ``img_shape``
                is read and unpacked as (img_h, img_w).

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

            - labels (Tensor): Labels of each image. Unmatched queries get
              the background index ``num_classes``.
            - label_weights (Tensor): Label weights of each image, all ones.
            - bbox_targets (Tensor): BBox targets of each image, normalized
              (cx, cy, w, h); zero for unmatched queries.
            - bbox_weights (Tensor): BBox weights of each image, one for
              matched queries and zero otherwise.
            - pos_inds (Tensor): Indices of queries assigned to a ground
              truth box (``gt_inds > 0``).
            - neg_inds (Tensor): Indices of queries assigned to background
              (``gt_inds == 0``).
        """
        img_h, img_w = img_meta['img_shape']
        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        num_bboxes = bbox_pred.size(0)
        # convert bbox_pred from xywh, normalized to xyxy, unnormalized
        bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
        bbox_pred = bbox_pred * factor

        pred_instances = InstanceData(scores=cls_score, bboxes=bbox_pred)
        # assigner only; DETR does not use a sampler
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=gt_instances,
            img_meta=img_meta)

        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        pos_inds = torch.nonzero(
            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(
            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        # gt_inds are 1-based for matched queries, hence the ``- 1``
        pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
        pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :]

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets; ``bbox_pred`` is only used here for its shape/dtype
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0

        # DETR regress the relative position of boxes (cxcywh) in the image.
        # Thus the learning target should be normalized by the image size, also
        # the box format should be converted from the default x1y1x2y2 to
        # cxcywh.
        pos_gt_bboxes_normalized = pos_gt_bboxes / factor
        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)

    def loss_and_predict(
            self, hidden_states: Tuple[Tensor],
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Over-write because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (tuple[Tensor]): Feature from the transformer
                decoder, has shape (num_decoder_layers, bs, num_queries, dim).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
                the meta information of each image and corresponding
                annotations.

        Returns:
            tuple: the return value is a tuple contains:

            - losses: (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions

    def predict(self,
                hidden_states: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Over-write
        because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (tuple[Tensor]): Features from the transformer
                decoder. Only the last entry along the first dimension is
                passed through the head.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # keep a leading "layer" dimension of size 1 for ``forward``
        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        layer_cls_scores: Tensor,
                        layer_bbox_preds: Tensor,
                        batch_img_metas: List[dict],
                        rescale: bool = True) -> InstanceList:
        """Transform network outputs for a batch into bbox predictions.

        Args:
            layer_cls_scores (Tensor): Classification outputs of the last or
                all decoder layer. Each is a 4D-tensor, has shape
                (num_decoder_layers, bs, num_queries, cls_out_channels).
            layer_bbox_preds (Tensor): Sigmoid regression outputs of the last
                or all decoder layer. Each is a 4D-tensor with normalized
                coordinate format (cx, cy, w, h) and shape
                (num_decoder_layers, bs, num_queries, 4).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        # NOTE only the outputs from the last decoder layer are used.
        cls_scores = layer_cls_scores[-1]
        bbox_preds = layer_bbox_preds[-1]

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            results = self._predict_by_feat_single(
                cls_scores[img_id], bbox_preds[img_id], img_meta, rescale)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score: Tensor,
                                bbox_pred: Tensor,
                                img_meta: dict,
                                rescale: bool = True) -> InstanceData:
        """Transform outputs from the last decoder layer into bbox predictions
        for each image.

        Args:
            cls_score (Tensor): Box score logits from the last decoder layer
                for each image. Shape [num_queries, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
                for each image, with coordinate format (cx, cy, w, h) and
                shape [num_queries, 4].
            img_meta (dict): Image meta info. ``img_shape`` is always read;
                ``scale_factor`` is required when ``rescale`` is True.
            rescale (bool): If True, return boxes in original image
                space. Default True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).

            At most ``max_per_img`` instances are returned, fewer when
            there are not that many candidates.
        """
        assert len(cls_score) == len(bbox_pred)  # num_queries
        max_per_img = self.test_cfg.get('max_per_img', len(cls_score))
        img_shape = img_meta['img_shape']
        # exclude background
        if self.loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
            flat_scores = cls_score.view(-1)
            # topk raises if k exceeds the number of candidates
            num_keep = min(max_per_img, flat_scores.numel())
            scores, indexes = flat_scores.topk(num_keep)
            # flat index -> (query index, class index)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_pred = bbox_pred[bbox_index]
        else:
            # drop the last (background) channel before taking the best class
            scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
            # topk raises if k exceeds the number of queries
            num_keep = min(max_per_img, scores.numel())
            scores, bbox_index = scores.topk(num_keep)
            bbox_pred = bbox_pred[bbox_index]
            det_labels = det_labels[bbox_index]

        # normalized cxcywh -> absolute xyxy, clipped to the image
        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
        det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
        det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
        if rescale:
            assert img_meta.get('scale_factor') is not None
            det_bboxes /= det_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        results = InstanceData()
        results.bboxes = det_bboxes
        results.scores = scores
        results.labels = det_labels
        return results
|