# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType


class DetrTransformerEncoder(BaseModule):
    """Encoder of DETR.

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries, has
                shape (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
            `True`, otherwise (num_queries, bs, dim).
        """
        for layer in self.layers:
            query = layer(query, query_pos, key_padding_mask, **kwargs)
        return query


class DetrTransformerDecoder(BaseModule):
    """Decoder of DETR.

    Args:
        num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
        post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            post normalization layer. Defaults to `LN`.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 post_norm_cfg: OptConfigType = dict(type='LN'),
                 return_intermediate: bool = True,
                 init_cfg: Union[dict, ConfigDict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.layer_cfg = layer_cfg
        self.num_layers = num_layers
        self.post_norm_cfg = post_norm_cfg
        self.return_intermediate = return_intermediate
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
                **kwargs) -> Tensor:
        """Forward function of the decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor): The input key, has shape (bs, num_keys, dim).
            value (Tensor): The input value with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).

        Returns:
            Tensor: The forwarded results will have shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True` else (1, bs, num_queries, dim).
        """
        intermediate = []
        for layer in self.layers:
            query = layer(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                **kwargs)
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))
        query = self.post_norm(query)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return query.unsqueeze(0)


class DetrTransformerEncoderLayer(BaseModule):
    """Implements encoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256, num_heads=8, dropout=0.0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.self_attn_cfg = self_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, \
                'First dimension of all DETRs in mmdet is `batch`, ' \
                'please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, FFN, and normalization."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of an encoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)

        return query


class DetrTransformerDecoderLayer(BaseModule):
    """Implements decoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for
            cross attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
""" def __init__(self, self_attn_cfg: OptConfigType = dict( embed_dims=256, num_heads=8, dropout=0.0, batch_first=True), cross_attn_cfg: OptConfigType = dict( embed_dims=256, num_heads=8, dropout=0.0, batch_first=True), ffn_cfg: OptConfigType = dict( embed_dims=256, feedforward_channels=1024, num_fcs=2, ffn_drop=0., act_cfg=dict(type='ReLU', inplace=True), ), norm_cfg: OptConfigType = dict(type='LN'), init_cfg: OptConfigType = None) -> None: super().__init__(init_cfg=init_cfg) self.self_attn_cfg = self_attn_cfg self.cross_attn_cfg = cross_attn_cfg if 'batch_first' not in self.self_attn_cfg: self.self_attn_cfg['batch_first'] = True else: assert self.self_attn_cfg['batch_first'] is True, 'First \ dimension of all DETRs in mmdet is `batch`, \ please set `batch_first` flag.' if 'batch_first' not in self.cross_attn_cfg: self.cross_attn_cfg['batch_first'] = True else: assert self.cross_attn_cfg['batch_first'] is True, 'First \ dimension of all DETRs in mmdet is `batch`, \ please set `batch_first` flag.' self.ffn_cfg = ffn_cfg self.norm_cfg = norm_cfg self._init_layers() def _init_layers(self) -> None: """Initialize self-attention, FFN, and normalization.""" self.self_attn = MultiheadAttention(**self.self_attn_cfg) self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(3) ] self.norms = ModuleList(norms_list) def forward(self, query: Tensor, key: Tensor = None, value: Tensor = None, query_pos: Tensor = None, key_pos: Tensor = None, self_attn_mask: Tensor = None, cross_attn_mask: Tensor = None, key_padding_mask: Tensor = None, **kwargs) -> Tensor: """ Args: query (Tensor): The input query, has shape (bs, num_queries, dim). key (Tensor, optional): The input key, has shape (bs, num_keys, dim). If `None`, the `query` will be used. Defaults to `None`. value (Tensor, optional): The input value, has the same shape as `key`, as in `nn.MultiheadAttention.forward`. If `None`, the `key` will be used. Defaults to `None`. query_pos (Tensor, optional): The positional encoding for `query`, has the same shape as `query`. If not `None`, it will be added to `query` before forward function. Defaults to `None`. key_pos (Tensor, optional): The positional encoding for `key`, has the same shape as `key`. If not `None`, it will be added to `key` before forward function. If None, and `query_pos` has the same shape as `key`, then `query_pos` will be used for `key_pos`. Defaults to None. self_attn_mask (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. Defaults to None. cross_attn_mask (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. Defaults to None. key_padding_mask (Tensor, optional): The `key_padding_mask` of `self_attn` input. ByteTensor, has shape (bs, num_value). Defaults to None. Returns: Tensor: forwarded results, has shape (bs, num_queries, dim). """ query = self.self_attn( query=query, key=query, value=query, query_pos=query_pos, key_pos=query_pos, attn_mask=self_attn_mask, **kwargs) query = self.norms[0](query) query = self.cross_attn( query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos, attn_mask=cross_attn_mask, key_padding_mask=key_padding_mask, **kwargs) query = self.norms[1](query) query = self.ffn(query) query = self.norms[2](query) return query