# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from torch import Tensor

from .deformable_detr_layers import DeformableDetrTransformerEncoder
from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer


class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder in PixelDecoder of Mask2Former."""

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                reference_points: Tensor, **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim). If not None, it will be added to
                `query` before the forward function. Defaults to None.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all
                levels, has shape (num_levels, 2), last dimension
                represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the
                valid height relative to the width and the height of
                features in all levels, has shape (bs, num_levels, 2).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim).
        """
        # Unlike `DeformableDetrTransformerEncoder.forward`, the reference
        # points are computed by the caller and passed in directly.
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query


class Mask2FormerTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Mask2Former."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            Mask2FormerTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]


class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Mask2Former transformer."""

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query, has shape
                (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape
                (bs, num_keys, dim). If `None`, the `query` will be used.
                Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for
                `query`, has the same shape as `query`. If not `None`, it
                will be added to `query` before the forward function.
                Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`,
                has the same shape as `key`. If not `None`, it will be added
                to `key` before the forward function. If `None`, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        # Mask2Former runs cross-attention before self-attention, which is
        # the reverse of the standard DETR decoder layer ordering.
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
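

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the library: it exercises
    # Mask2FormerTransformerDecoderLayer assuming the default configuration
    # inherited from DetrTransformerDecoderLayer (embed_dims=256, 8 attention
    # heads, batch-first tensors). Run as a module
    # (`python -m <package>.mask2former_layers`) so the relative imports
    # above resolve; shapes and names below are illustrative only.
    import torch

    layer = Mask2FormerTransformerDecoderLayer()

    bs, num_queries, num_keys, dim, num_heads = 2, 100, 1024, 256, 8
    query = torch.randn(bs, num_queries, dim)      # object queries
    memory = torch.randn(bs, num_keys, dim)        # flattened pixel features
    query_pos = torch.randn(bs, num_queries, dim)  # query positional embeds
    key_pos = torch.randn(bs, num_keys, dim)       # key positional encoding
    # Boolean cross-attention mask expanded over heads as
    # (bs * num_heads, num_queries, num_keys); True marks masked positions.
    cross_attn_mask = torch.zeros(
        bs * num_heads, num_queries, num_keys, dtype=torch.bool)

    out = layer(
        query=query,
        key=memory,
        value=memory,
        query_pos=query_pos,
        key_pos=key_pos,
        cross_attn_mask=cross_attn_mask)
    assert out.shape == (bs, num_queries, dim)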