# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from torch import Tensor
from torch.nn import ModuleList

from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
from .utils import MLP, ConditionalAttention, coordinate_to_encoding


class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Conditional DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers and other layers."""
        self.layers = ModuleList([
            ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
        # conditional detr affine transformation of the spatial query
        self.query_scale = MLP(self.embed_dims, self.embed_dims,
                               self.embed_dims, 2)
        self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
        # 'qpos_proj' has been substituted with 'qpos_sine_proj' in all
        # decoder layers except the first one, so 'qpos_proj' should be
        # deleted in the other layers.
        for layer_id in range(self.num_layers - 1):
            self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                key_padding_mask: Tensor = None):
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape
                (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim). If
                `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`. If not `None`, it will be added to
                `query` before forward function. Defaults to `None`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If `None`, and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                as `key_pos`. Defaults to `None`.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            tuple[Tensor, Tensor]: The forwarded queries with shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is True, otherwise with shape
            (1, bs, num_queries, dim), and the references with shape
            (bs, num_queries, 2).
        """
        reference_unsigmoid = self.ref_point_head(
            query_pos)  # [bs, num_queries, 2]
        reference = reference_unsigmoid.sigmoid()
        reference_xy = reference[..., :2]
        intermediate = []
        for layer_id, layer in enumerate(self.layers):
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(query)
            # get sine embedding for the query reference
            ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
            # apply transformation
            ref_sine_embed = ref_sine_embed * pos_transformation
            query = layer(
                query,
                key=key,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                ref_sine_embed=ref_sine_embed,
                is_first=(layer_id == 0))
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))

        if self.return_intermediate:
            return torch.stack(intermediate), reference

        query = self.post_norm(query)
        return query.unsqueeze(0), reference
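
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original implementation): how the
# forward loop above assembles the conditional spatial query for one layer.
# The helper name `_conditional_spatial_query_sketch` is an assumption added
# purely for illustration; the authoritative logic is the loop body of
# `ConditionalDetrTransformerDecoder.forward`.
# ---------------------------------------------------------------------------
def _conditional_spatial_query_sketch(query: Tensor, reference_xy: Tensor,
                                      query_scale: MLP,
                                      layer_id: int) -> Tensor:
    """Sketch of the per-layer conditional spatial query computation.

    The 2-d reference point is mapped to a sinusoidal embedding and, from
    the second layer on, rescaled element-wise by an MLP of the current
    content query.

    Args:
        query (Tensor): Content query, has shape (bs, num_queries, dim).
        reference_xy (Tensor): Sigmoided reference points, has shape
            (bs, num_queries, 2).
        query_scale (MLP): The decoder's ``query_scale`` head.
        layer_id (int): Index of the current decoder layer.
    """
    ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
    if layer_id == 0:
        # the first layer uses the raw sine embedding (identity scaling)
        return ref_sine_embed
    return ref_sine_embed * query_scale(query)
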
""" reference_unsigmoid = self.ref_point_head( query_pos) # [bs, num_queries, 2] reference = reference_unsigmoid.sigmoid() reference_xy = reference[..., :2] intermediate = [] for layer_id, layer in enumerate(self.layers): if layer_id == 0: pos_transformation = 1 else: pos_transformation = self.query_scale(query) # get sine embedding for the query reference ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy) # apply transformation ref_sine_embed = ref_sine_embed * pos_transformation query = layer( query, key=key, query_pos=query_pos, key_pos=key_pos, key_padding_mask=key_padding_mask, ref_sine_embed=ref_sine_embed, is_first=(layer_id == 0)) if self.return_intermediate: intermediate.append(self.post_norm(query)) if self.return_intermediate: return torch.stack(intermediate), reference query = self.post_norm(query) return query.unsqueeze(0), reference class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): """Implements decoder layer in Conditional DETR transformer.""" def _init_layers(self): """Initialize self-attention, cross-attention, FFN, and normalization.""" self.self_attn = ConditionalAttention(**self.self_attn_cfg) self.cross_attn = ConditionalAttention(**self.cross_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(3) ] self.norms = ModuleList(norms_list) def forward(self, query: Tensor, key: Tensor = None, query_pos: Tensor = None, key_pos: Tensor = None, self_attn_masks: Tensor = None, cross_attn_masks: Tensor = None, key_padding_mask: Tensor = None, ref_sine_embed: Tensor = None, is_first: bool = False): """ Args: query (Tensor): The input query, has shape (bs, num_queries, dim) key (Tensor, optional): The input key, has shape (bs, num_keys, dim). If `None`, the `query` will be used. Defaults to `None`. query_pos (Tensor, optional): The positional encoding for `query`, has the same shape as `query`. If not `None`, it will be added to `query` before forward function. Defaults to `None`. ref_sine_embed (Tensor): The positional encoding for query in cross attention, with the same shape as `x`. Defaults to None. key_pos (Tensor, optional): The positional encoding for `key`, has the same shape as `key`. If not None, it will be added to `key` before forward function. If None, and `query_pos` has the same shape as `key`, then `query_pos` will be used for `key_pos`. Defaults to None. self_attn_masks (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), Same in `nn.MultiheadAttention. forward`. Defaults to None. cross_attn_masks (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), Same in `nn.MultiheadAttention. forward`. Defaults to None. key_padding_mask (Tensor, optional): ByteTensor, has shape (bs, num_keys). Defaults to None. is_first (bool): A indicator to tell whether the current layer is the first layer of the decoder. Defaults to False. Returns: Tensor: Forwarded results, has shape (bs, num_queries, dim). """ query = self.self_attn( query=query, key=query, query_pos=query_pos, key_pos=query_pos, attn_mask=self_attn_masks) query = self.norms[0](query) query = self.cross_attn( query=query, key=key, query_pos=query_pos, key_pos=key_pos, attn_mask=cross_attn_masks, key_padding_mask=key_padding_mask, ref_sine_embed=ref_sine_embed, is_first=is_first) query = self.norms[1](query) query = self.ffn(query) query = self.norms[2](query) return query