# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import ModuleList
from torch import Tensor, nn

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import inverse_sigmoid


class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
    """Transformer encoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for `query`, has
                shape (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in
                all levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of the Transformer encoder, also called
            'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim).
        """
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in the encoder.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in
                all levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device on which the
                `reference_points` are created.

        Returns:
            Tensor: Reference points used in the encoder, has shape
            (bs, length, num_levels, 2), where `length` is the sum of
            h*w over all levels.
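
        Examples:
            A minimal, made-up example (the inputs below are illustrative
            assumptions, not values from the original implementation): two
            feature levels and one fully valid image, so ``valid_ratios``
            is all ones.

            >>> encoder_cls = DeformableDetrTransformerEncoder
            >>> spatial_shapes = torch.tensor([[2, 2], [1, 1]])
            >>> valid_ratios = torch.ones(1, 2, 2)
            >>> pts = encoder_cls.get_encoder_reference_points(
            ...     spatial_shapes, valid_ratios, device='cpu')
            >>> pts.shape  # (bs, sum of h*w, num_levels, 2)
            torch.Size([1, 5, 2, 2])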
""" reference_points_list = [] for lvl, (H, W) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid( torch.linspace( 0.5, H - 0.5, H, dtype=torch.float32, device=device), torch.linspace( 0.5, W - 0.5, W, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 1] * H) ref_x = ref_x.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 0] * W) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) # [bs, sum(hw), num_level, 2] reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points class DeformableDetrTransformerDecoder(DetrTransformerDecoder): """Transformer Decoder of Deformable DETR.""" def _init_layers(self) -> None: """Initialize decoder layers.""" self.layers = ModuleList([ DeformableDetrTransformerDecoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.embed_dims = self.layers[0].embed_dims if self.post_norm_cfg is not None: raise ValueError('There is not post_norm in ' f'{self._get_name()}') def forward(self, query: Tensor, query_pos: Tensor, value: Tensor, key_padding_mask: Tensor, reference_points: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, reg_branches: Optional[nn.Module] = None, **kwargs) -> Tuple[Tensor]: """Forward function of Transformer decoder. Args: query (Tensor): The input queries, has shape (bs, num_queries, dim). query_pos (Tensor): The input positional query, has shape (bs, num_queries, dim). It will be added to `query` before forward function. value (Tensor): The input values, has shape (bs, num_value, dim). key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` input. ByteTensor, has shape (bs, num_value). reference_points (Tensor): The initial reference, has shape (bs, num_queries, 4) with the last dimension arranged as (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has shape (bs, num_queries, 2) with the last dimension arranged as (cx, cy). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). reg_branches: (obj:`nn.ModuleList`, optional): Used for refining the regression results. Only would be passed when `with_box_refine` is `True`, otherwise would be `None`. Returns: tuple[Tensor]: Outputs of Deformable Transformer Decoder. - output (Tensor): Output embeddings of the last decoder, has shape (num_queries, bs, embed_dims) when `return_intermediate` is `False`. Otherwise, Intermediate output embeddings of all decoder layers, has shape (num_decoder_layers, num_queries, bs, embed_dims). - reference_points (Tensor): The reference of the last decoder layer, has shape (bs, num_queries, 4) when `return_intermediate` is `False`. Otherwise, Intermediate references of all decoder layers, has shape (num_decoder_layers, bs, num_queries, 4). 
              The coordinates are arranged as (cx, cy, w, h).
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for layer_id, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                tmp_reg_preds = reg_branches[layer_id](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp_reg_preds + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp_reg_preds
                    new_reference_points[..., :2] = tmp_reg_preds[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points


class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
    """Encoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn, and norms."""
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)


class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, cross-attn, ffn, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)
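

# The block below is an illustrative smoke test, not part of the original
# module: the layer settings, feature-map sizes, and FFN width are
# assumptions chosen only to exercise the encoder end to end on CPU.
if __name__ == '__main__':
    bs, embed_dims, num_levels = 2, 256, 2
    # Two assumed feature levels; queries are all level pixels flattened.
    spatial_shapes = torch.tensor([[32, 32], [16, 16]])
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]))
    num_queries = int(spatial_shapes.prod(1).sum())

    encoder = DeformableDetrTransformerEncoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=embed_dims,
                num_levels=num_levels,
                batch_first=True),
            ffn_cfg=dict(embed_dims=embed_dims, feedforward_channels=1024)))

    # A fully valid batch: no padding, so the mask is all False and
    # `valid_ratios` is all ones.
    memory = encoder(
        query=torch.randn(bs, num_queries, embed_dims),
        query_pos=torch.randn(bs, num_queries, embed_dims),
        key_padding_mask=torch.zeros(bs, num_queries, dtype=torch.bool),
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=torch.ones(bs, num_levels, 2))
    assert memory.shape == (bs, num_queries, embed_dims)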