# Copyright (c) OpenMMLab. All rights reserved. from typing import List import torch import torch.nn as nn from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN from mmengine.model import ModuleList from torch import Tensor from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, DetrTransformerEncoder, DetrTransformerEncoderLayer) from .utils import (MLP, ConditionalAttention, coordinate_to_encoding, inverse_sigmoid) class DABDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): """Implements decoder layer in DAB-DETR transformer.""" def _init_layers(self): """Initialize self-attention, cross-attention, FFN, normalization and others.""" self.self_attn = ConditionalAttention(**self.self_attn_cfg) self.cross_attn = ConditionalAttention(**self.cross_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(3) ] self.norms = ModuleList(norms_list) self.keep_query_pos = self.cross_attn.keep_query_pos def forward(self, query: Tensor, key: Tensor, query_pos: Tensor, key_pos: Tensor, ref_sine_embed: Tensor = None, self_attn_masks: Tensor = None, cross_attn_masks: Tensor = None, key_padding_mask: Tensor = None, is_first: bool = False, **kwargs) -> Tensor: """ Args: query (Tensor): The input query with shape [bs, num_queries, dim]. key (Tensor): The key tensor with shape [bs, num_keys, dim]. query_pos (Tensor): The positional encoding for query in self attention, with the same shape as `x`. key_pos (Tensor): The positional encoding for `key`, with the same shape as `key`. ref_sine_embed (Tensor): The positional encoding for query in cross attention, with the same shape as `x`. Defaults to None. self_attn_masks (Tensor): ByteTensor mask with shape [num_queries, num_keys]. Same in `nn.MultiheadAttention.forward`. Defaults to None. cross_attn_masks (Tensor): ByteTensor mask with shape [num_queries, num_keys]. Same in `nn.MultiheadAttention.forward`. Defaults to None. key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. Defaults to None. is_first (bool): A indicator to tell whether the current layer is the first layer of the decoder. Defaults to False. Returns: Tensor: forwarded results with shape [bs, num_queries, dim]. """ query = self.self_attn( query=query, key=query, query_pos=query_pos, key_pos=query_pos, attn_mask=self_attn_masks, **kwargs) query = self.norms[0](query) query = self.cross_attn( query=query, key=key, query_pos=query_pos, key_pos=key_pos, ref_sine_embed=ref_sine_embed, attn_mask=cross_attn_masks, key_padding_mask=key_padding_mask, is_first=is_first, **kwargs) query = self.norms[1](query) query = self.ffn(query) query = self.norms[2](query) return query class DABDetrTransformerDecoder(DetrTransformerDecoder): """Decoder of DAB-DETR. Args: query_dim (int): The last dimension of query pos, 4 for anchor format, 2 for point format. Defaults to 4. query_scale_type (str): Type of transformation applied to content query. Defaults to `cond_elewise`. with_modulated_hw_attn (bool): Whether to inject h&w info during cross conditional attention. Defaults to True. """ def __init__(self, *args, query_dim: int = 4, query_scale_type: str = 'cond_elewise', with_modulated_hw_attn: bool = True, **kwargs): self.query_dim = query_dim self.query_scale_type = query_scale_type self.with_modulated_hw_attn = with_modulated_hw_attn super().__init__(*args, **kwargs) def _init_layers(self): """Initialize decoder layers and other layers.""" assert self.query_dim in [2, 4], \ f'{"dab-detr only supports anchor prior or reference point prior"}' assert self.query_scale_type in [ 'cond_elewise', 'cond_scalar', 'fix_elewise' ] self.layers = ModuleList([ DABDetrTransformerDecoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) embed_dims = self.layers[0].embed_dims self.embed_dims = embed_dims self.post_norm = build_norm_layer(self.post_norm_cfg, embed_dims)[1] if self.query_scale_type == 'cond_elewise': self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2) elif self.query_scale_type == 'cond_scalar': self.query_scale = MLP(embed_dims, embed_dims, 1, 2) elif self.query_scale_type == 'fix_elewise': self.query_scale = nn.Embedding(self.num_layers, embed_dims) else: raise NotImplementedError('Unknown query_scale_type: {}'.format( self.query_scale_type)) self.ref_point_head = MLP(self.query_dim // 2 * embed_dims, embed_dims, embed_dims, 2) if self.with_modulated_hw_attn and self.query_dim == 4: self.ref_anchor_head = MLP(embed_dims, embed_dims, 2, 2) self.keep_query_pos = self.layers[0].keep_query_pos if not self.keep_query_pos: for layer_id in range(self.num_layers - 1): self.layers[layer_id + 1].cross_attn.qpos_proj = None def forward(self, query: Tensor, key: Tensor, query_pos: Tensor, key_pos: Tensor, reg_branches: nn.Module, key_padding_mask: Tensor = None, **kwargs) -> List[Tensor]: """Forward function of decoder. Args: query (Tensor): The input query with shape (bs, num_queries, dim). key (Tensor): The input key with shape (bs, num_keys, dim). query_pos (Tensor): The positional encoding for `query`, with the same shape as `query`. key_pos (Tensor): The positional encoding for `key`, with the same shape as `key`. reg_branches (nn.Module): The regression branch for dynamically updating references in each layer. key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys). Defaults to `None`. Returns: List[Tensor]: forwarded results with shape (num_decoder_layers, bs, num_queries, dim) if `return_intermediate` is True, otherwise with shape (1, bs, num_queries, dim). references with shape (num_decoder_layers, bs, num_queries, 2/4). """ output = query unsigmoid_references = query_pos reference_points = unsigmoid_references.sigmoid() intermediate_reference_points = [reference_points] intermediate = [] for layer_id, layer in enumerate(self.layers): obj_center = reference_points[..., :self.query_dim] ref_sine_embed = coordinate_to_encoding( coord_tensor=obj_center, num_feats=self.embed_dims // 2) query_pos = self.ref_point_head( ref_sine_embed) # [bs, nq, 2c] -> [bs, nq, c] # For the first decoder layer, do not apply transformation if self.query_scale_type != 'fix_elewise': if layer_id == 0: pos_transformation = 1 else: pos_transformation = self.query_scale(output) else: pos_transformation = self.query_scale.weight[layer_id] # apply transformation ref_sine_embed = ref_sine_embed[ ..., :self.embed_dims] * pos_transformation # modulated height and weight attention if self.with_modulated_hw_attn: assert obj_center.size(-1) == 4 ref_hw = self.ref_anchor_head(output).sigmoid() ref_sine_embed[..., self.embed_dims // 2:] *= \ (ref_hw[..., 0] / obj_center[..., 2]).unsqueeze(-1) ref_sine_embed[..., : self.embed_dims // 2] *= \ (ref_hw[..., 1] / obj_center[..., 3]).unsqueeze(-1) output = layer( output, key, query_pos=query_pos, ref_sine_embed=ref_sine_embed, key_pos=key_pos, key_padding_mask=key_padding_mask, is_first=(layer_id == 0), **kwargs) # iter update tmp_reg_preds = reg_branches(output) tmp_reg_preds[..., :self.query_dim] += inverse_sigmoid( reference_points) new_reference_points = tmp_reg_preds[ ..., :self.query_dim].sigmoid() if layer_id != self.num_layers - 1: intermediate_reference_points.append(new_reference_points) reference_points = new_reference_points.detach() if self.return_intermediate: intermediate.append(self.post_norm(output)) output = self.post_norm(output) if self.return_intermediate: return [ torch.stack(intermediate), torch.stack(intermediate_reference_points), ] else: return [ output.unsqueeze(0), torch.stack(intermediate_reference_points) ] class DABDetrTransformerEncoder(DetrTransformerEncoder): """Encoder of DAB-DETR.""" def _init_layers(self): """Initialize encoder layers.""" self.layers = ModuleList([ DetrTransformerEncoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) embed_dims = self.layers[0].embed_dims self.embed_dims = embed_dims self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2) def forward(self, query: Tensor, query_pos: Tensor, key_padding_mask: Tensor, **kwargs): """Forward function of encoder. Args: query (Tensor): Input queries of encoder, has shape (bs, num_queries, dim). query_pos (Tensor): The positional embeddings of the queries, has shape (bs, num_feat_points, dim). key_padding_mask (Tensor): ByteTensor, the key padding mask of the queries, has shape (bs, num_feat_points). Returns: Tensor: With shape (num_queries, bs, dim). """ for layer in self.layers: pos_scales = self.query_scale(query) query = layer( query, query_pos=query_pos * pos_scales, key_padding_mask=key_padding_mask, **kwargs) return query