# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from .detr_head import DETRHead


@MODELS.register_module()
class ConditionalDETRHead(DETRHead):
    """Head of Conditional DETR.

    Conditional DETR: Conditional DETR for Fast Training Convergence.
    More details can be found in the `paper
    <https://arxiv.org/abs/2108.06152>`_ .
    """

    def init_weights(self):
        """Initialize weights of the transformer head."""
        super().init_weights()
        # The initialization below is important: when focal loss is used
        # for loss_cls, the classification bias must be initialized so
        # that the initial foreground probability is low.
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder.
                If `return_intermediate_dec` is True, it has shape
                (num_decoder_layers, bs, num_queries, dim); otherwise it
                has shape (1, bs, num_queries, dim), which only contains
                the last-layer outputs.
            references (Tensor): References from the transformer decoder,
                has shape (bs, num_queries, 2).

        Returns:
            tuple[Tensor]: Results of the head, containing the following
            tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, with shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note that cls_out_channels should include
              the background class.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format
              (cx, cy, w, h), with shape (num_decoder_layers, bs,
              num_queries, 4).
        """
        references_unsigmoid = inverse_sigmoid(references)
        layers_bbox_preds = []
        for layer_id in range(hidden_states.shape[0]):
            tmp_reg_preds = self.fc_reg(
                self.activate(self.reg_ffn(hidden_states[layer_id])))
            tmp_reg_preds[..., :2] += references_unsigmoid
            outputs_coord = tmp_reg_preds.sigmoid()
            layers_bbox_preds.append(outputs_coord)
        layers_bbox_preds = torch.stack(layers_bbox_preds)

        layers_cls_scores = self.fc_cls(hidden_states)
        return layers_cls_scores, layers_bbox_preds
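    # Illustrative aside on the regression in `forward` above: each query
    # carries one normalized 2-D reference point (cx, cy). Adding
    # `inverse_sigmoid(references)` to the raw (cx, cy) regression logits
    # before the final sigmoid means every query predicts an offset from
    # its own reference point rather than an absolute box center, which is
    # what Conditional DETR relies on for fast training convergence.
    # A minimal stand-alone sketch of the same transform on dummy tensors
    # (all shapes are illustrative assumptions, not taken from a real
    # model):
    #
    #     refs = torch.rand(2, 300, 2)       # (bs, num_queries, 2)
    #     preds = torch.randn(2, 300, 4)     # raw (cx, cy, w, h) logits
    #     preds[..., :2] += inverse_sigmoid(refs)
    #     boxes = preds.sigmoid()            # normalized (cx, cy, w, h)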
    def loss(self, hidden_states: Tensor, references: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the
        detection head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_and_predict(
            self, hidden_states: Tensor, references: Tensor,
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Overwritten because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, 2).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and
                corresponding annotations.

        Returns:
            tuple: The return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection results of
              each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        Overwritten because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions
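
# A minimal forward-pass smoke test (an illustrative sketch only: it
# assumes mmdet and its default registries are importable, and the shapes
# below -- 6 decoder layers, batch size 2, 300 queries, 256-dim
# embeddings, 80 classes -- are made-up assumptions; copy it into a
# script rather than running this module directly, since the relative
# import above would fail).
if __name__ == '__main__':
    head = ConditionalDETRHead(num_classes=80)
    hidden_states = torch.randn(6, 2, 300, 256)  # fake decoder features
    references = torch.rand(2, 300, 2)  # normalized (cx, cy) per query
    cls_scores, bbox_preds = head(hidden_states, references)
    # Expected: cls_scores (6, 2, 300, head.cls_out_channels),
    # bbox_preds (6, 2, 300, 4) in normalized (cx, cy, w, h) format.
    print(cls_scores.shape, bbox_preds.shape)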