# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DeformableDETRHead(DETRHead):
    r"""Head of Deformable DETR: Deformable Transformers for End-to-End
    Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_.

    Args:
        share_pred_layer (bool): Whether to share parameters for all the
            prediction layers. Defaults to `False`.
        num_pred_layer (int): The number of prediction layers.
            Defaults to 6.
        as_two_stage (bool, optional): Whether to generate proposals
            from the encoder outputs. Defaults to `False`.
    """

    def __init__(self,
                 *args,
                 share_pred_layer: bool = False,
                 num_pred_layer: int = 6,
                 as_two_stage: bool = False,
                 **kwargs) -> None:
        # These attributes must be set before calling `super().__init__()`,
        # which triggers `_init_layers()` and reads them.
        self.share_pred_layer = share_pred_layer
        self.num_pred_layer = num_pred_layer
        self.as_two_stage = as_two_stage

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        if self.share_pred_layer:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(self.num_pred_layer)])
        else:
            self.cls_branches = nn.ModuleList(
                [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList([
                copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
            ])
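    # Note on parameter sharing: with `share_pred_layer=True` the very same
    # module object is reused for every prediction layer, so all layers share
    # weights and gradients accumulate across decoder layers; with `False`,
    # `copy.deepcopy` gives each layer independent weights. A quick check
    # (illustrative):
    #
    #     head.cls_branches[0] is head.cls_branches[1]  # True iff shared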

    def init_weights(self) -> None:
        """Initialize weights of the Deformable DETR head."""
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)
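    # For reference: `bias_init_with_prob(0.01)` computes
    # `-log((1 - 0.01) / 0.01) ≈ -4.59`, the usual focal-loss prior so that
    # every class starts at roughly 1% predicted probability. Setting the
    # last reg layer's `bias[2:]` (the w/h logits) to -2.0 biases initial
    # boxes to be small, since `sigmoid(-2.0) ≈ 0.12`.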

    def forward(self, hidden_states: Tensor,
                references: List[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other `num_decoder_layers` references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as (cx, cy, w, h)
                when it is 4.

        Returns:
            tuple[Tensor]: Results of the head, containing the following
            tensors.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head, has shape (num_decoder_layers, bs,
              num_queries, 4) with the last dimension in normalized
              (cx, cy, w, h) format.
        """
        all_layers_outputs_classes = []
        all_layers_outputs_coords = []

        for layer_id in range(hidden_states.shape[0]):
            reference = inverse_sigmoid(references[layer_id])
            # NOTE The last reference will not be used.
            hidden_state = hidden_states[layer_id]
            outputs_class = self.cls_branches[layer_id](hidden_state)
            tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
            if reference.shape[-1] == 4:
                # When `layer_id` is 0 and `as_two_stage` of the detector
                # is `True`, or when `layer_id` is greater than 0 and
                # `with_box_refine` of the detector is `True`.
                tmp_reg_preds += reference
            else:
                # When `layer_id` is 0 and `as_two_stage` of the detector
                # is `False`, or when `layer_id` is greater than 0 and
                # `with_box_refine` of the detector is `False`.
                assert reference.shape[-1] == 2
                tmp_reg_preds[..., :2] += reference
            outputs_coord = tmp_reg_preds.sigmoid()
            all_layers_outputs_classes.append(outputs_class)
            all_layers_outputs_coords.append(outputs_coord)

        all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
        return all_layers_outputs_classes, all_layers_outputs_coords

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other `num_decoder_layers` references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as (cx, cy, w, h)
                when it is 4.
            enc_outputs_class (Tensor): The score of each point on the
                encoder feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`; otherwise it is `None`.
            enc_outputs_coord (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4) with
                the last dimension arranged as (cx, cy, w, h). It is only
                passed in when `as_two_stage` is `True`; otherwise it is
                `None`.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        enc_cls_scores: Tensor,
        enc_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers, has shape (num_decoder_layers, bs, num_queries, 4)
                with the last dimension in normalized (cx, cy, w, h) format.
            enc_cls_scores (Tensor): The score of each point on the encoder
                feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`; otherwise it is `None`.
            enc_bbox_preds (Tensor): The proposals generated from the encoder
                feature map, has shape (bs, num_feat_points, 4) with the last
                dimension arranged as (cx, cy, w, h). It is only passed in
                when `as_two_stage` is `True`; otherwise it is `None`.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_dict = super().loss_by_feat(all_layers_cls_scores,
                                         all_layers_bbox_preds,
                                         batch_gt_instances, batch_img_metas,
                                         batch_gt_instances_ignore)

        # Loss of the proposals generated from the encoder feature map.
        if enc_cls_scores is not None:
            # The encoder proposal head is trained class-agnostically: every
            # gt label is replaced with 0, i.e. a single foreground class.
            proposal_gt_instances = copy.deepcopy(batch_gt_instances)
            for i in range(len(proposal_gt_instances)):
                proposal_gt_instances[i].labels = torch.zeros_like(
                    proposal_gt_instances[i].labels)
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=proposal_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou
        return loss_dict

    def predict(self,
                hidden_states: Tensor,
                references: List[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other `num_decoder_layers` references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as (cx, cy, w, h)
                when it is 4.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        outs = self(hidden_states, references)
        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        all_layers_cls_scores: Tensor,
                        all_layers_bbox_preds: Tensor,
                        batch_img_metas: List[Dict],
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers, has shape (num_decoder_layers, bs, num_queries, 4)
                with the last dimension in normalized (cx, cy, w, h) format.
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `False`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(cls_score, bbox_pred,
                                                   img_meta, rescale)
            result_list.append(results)
        return result_list
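
# Minimal usage sketch (illustrative only: the loss config and the sizes
# below are assumptions, not values taken from a shipped mmdet config):
#
#     import torch
#     head = DeformableDETRHead(
#         num_classes=80,
#         loss_cls=dict(type='FocalLoss', use_sigmoid=True, loss_weight=2.0))
#     hidden_states = torch.rand(6, 2, 300, 256)  # 6 decoder layers
#     references = [torch.rand(2, 300, 2) for _ in range(7)]
#     all_cls, all_coords = head(hidden_states, references)
#     # all_cls: (6, 2, 300, 80); all_coords: (6, 2, 300, 4), (cx, cy, w, h)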
|