# Copyright (c) OpenMMLab. All rights reserved.
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmcv.cnn import build_norm_layer
- from mmengine.model import ModuleList
- from torch import Tensor
- from .deformable_detr_layers import DeformableDetrTransformerEncoder
- from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder in PixelDecoder of Mask2Former."""

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                reference_points: Tensor, **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim). If not None, it will be added to the
                `query` before forward function. Defaults to None.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim)
        """
        # All per-layer arguments except `query` are identical for every
        # layer, so assemble them once outside the loop.
        layer_kwargs = dict(
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reference_points=reference_points,
            **kwargs)
        # Feed the query through each encoder layer in sequence.
        output = query
        for encoder_layer in self.layers:
            output = encoder_layer(query=output, **layer_kwargs)
        return output
class Mask2FormerTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Mask2Former."""

    def _init_layers(self) -> None:
        """Initialize decoder layers and the post normalization layer."""
        # Build `num_layers` identical decoder layers from the shared config.
        decoder_layers = [
            Mask2FormerTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ]
        self.layers = ModuleList(decoder_layers)
        self.embed_dims = decoder_layers[0].embed_dims
        # build_norm_layer returns a (name, layer) tuple; keep the layer.
        _, self.post_norm = build_norm_layer(self.post_norm_cfg,
                                             self.embed_dims)
class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Mask2Former transformer.

    NOTE(review): cross-attention is applied before self-attention here —
    this ordering appears intentional for Mask2Former's masked attention;
    confirm against the base class before reordering.
    """

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `self_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        # 1. Cross-attention to the (key, value) features, then normalize.
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        # 2. Self-attention among queries (query attends to itself), then
        #    normalize. Note `query_pos` doubles as the key positional encoding.
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[1](query)
        # 3. Feed-forward network followed by the final normalization.
        return self.norms[2](self.ffn(query))
|