# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType
class DetrTransformerEncoder(BaseModule):
    """Encoder of DETR.

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries, has
                shape (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
            `True`, otherwise (num_queries, bs, dim).
        """
        for layer in self.layers:
            query = layer(query, query_pos, key_padding_mask, **kwargs)
        return query
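
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): builds an
# encoder from the class above and runs a dummy forward pass to show the
# expected tensor shapes. The values below (6 layers, embed_dims=256,
# num_heads=8, feedforward_channels=2048, 850 flattened feature tokens) are
# assumptions mirroring typical DETR settings, not values mandated here.
#
#   layer_cfg = dict(
#       self_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
#       ffn_cfg=dict(embed_dims=256, feedforward_channels=2048))
#   encoder = DetrTransformerEncoder(num_layers=6, layer_cfg=layer_cfg)
#   feat = torch.rand(2, 850, 256)                       # (bs, num_queries, dim)
#   pos_embed = torch.rand(2, 850, 256)                  # same shape as `feat`
#   padding_mask = torch.zeros(2, 850, dtype=torch.bool) # (bs, num_queries)
#   memory = encoder(feat, query_pos=pos_embed,
#                    key_padding_mask=padding_mask)      # -> (2, 850, 256)
# ---------------------------------------------------------------------------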
class DetrTransformerDecoder(BaseModule):
    """Decoder of DETR.

    Args:
        num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
        post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            post normalization layer. Defaults to `LN`.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 post_norm_cfg: OptConfigType = dict(type='LN'),
                 return_intermediate: bool = True,
                 init_cfg: Union[dict, ConfigDict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.layer_cfg = layer_cfg
        self.num_layers = num_layers
        self.post_norm_cfg = post_norm_cfg
        self.return_intermediate = return_intermediate
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
                **kwargs) -> Tensor:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor): The input key, has shape (bs, num_keys, dim).
            value (Tensor): The input value with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).

        Returns:
            Tensor: The forwarded results will have shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True` else (1, bs, num_queries, dim).
        """
        intermediate = []
        for layer in self.layers:
            query = layer(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                **kwargs)
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))
        query = self.post_norm(query)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return query.unsqueeze(0)
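
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): decodes a set of object queries against the
# encoder memory from the sketch above (`memory`, `pos_embed`, `padding_mask`
# refer to that sketch). The sizes (6 layers, 100 object queries, dim=256)
# are assumptions for the example only. With `return_intermediate=True` the
# output stacks every layer's post-normalized query, giving shape
# (num_decoder_layers, bs, num_queries, dim).
#
#   decoder = DetrTransformerDecoder(
#       num_layers=6,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8),
#           cross_attn_cfg=dict(embed_dims=256, num_heads=8),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)))
#   object_query = torch.zeros(2, 100, 256)       # (bs, num_queries, dim)
#   query_embed = torch.rand(2, 100, 256)         # learned query embeddings
#   out = decoder(
#       query=object_query,
#       key=memory,                               # encoder output
#       value=memory,
#       query_pos=query_embed,
#       key_pos=pos_embed,
#       key_padding_mask=padding_mask)            # -> (6, 2, 100, 256)
# ---------------------------------------------------------------------------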
class DetrTransformerEncoderLayer(BaseModule):
    """Implements encoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256, num_heads=8, dropout=0.0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.self_attn_cfg = self_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, FFN, and normalization."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of an encoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query
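
# ---------------------------------------------------------------------------
# Note on the layer structure (explanatory; the claim about mmcv internals is
# an assumption based on mmcv's `MultiheadAttention` and `FFN`, which add the
# residual identity and dropout inside their own forward). Under that
# assumption the forward above realizes a standard post-LN encoder layer:
#
#   x = LayerNorm(x + SelfAttn(q=x + pos, k=x + pos, v=x))
#   x = LayerNorm(x + FFN(x))
#
# A single layer can also be exercised on its own, e.g.:
#
#   layer = DetrTransformerEncoderLayer()          # uses the defaults above
#   x = torch.rand(2, 850, 256)                    # (bs, num_queries, dim)
#   pos = torch.rand(2, 850, 256)
#   mask = torch.zeros(2, 850, dtype=torch.bool)
#   y = layer(x, query_pos=pos, key_padding_mask=mask)   # -> (2, 850, 256)
# ---------------------------------------------------------------------------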
class DetrTransformerDecoderLayer(BaseModule):
    """Implements decoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 cross_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.self_attn_cfg = self_attn_cfg
        self.cross_attn_cfg = cross_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        if 'batch_first' not in self.cross_attn_cfg:
            self.cross_attn_cfg['batch_first'] = True
        else:
            assert self.cross_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, cross-attention, FFN, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """Forward function of a decoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If None, and `query_pos` has
                the same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
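
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a single decoder layer first runs
# self-attention over the object queries, then cross-attention against the
# encoder memory. Shapes are assumptions for the example; `memory`,
# `pos_embed` and `padding_mask` refer to the encoder sketch above.
#
#   layer = DetrTransformerDecoderLayer()          # uses the defaults above
#   tgt = torch.zeros(2, 100, 256)                 # (bs, num_queries, dim)
#   query_embed = torch.rand(2, 100, 256)
#   out = layer(
#       tgt,
#       key=memory,
#       value=memory,
#       query_pos=query_embed,
#       key_pos=pos_embed,
#       key_padding_mask=padding_mask)             # -> (2, 100, 256)
#
# The computation order is post-LN: self-attn -> norm -> cross-attn -> norm
# -> FFN -> norm, with residual connections handled inside mmcv's
# `MultiheadAttention` and `FFN` blocks (an assumption about mmcv internals).
# ---------------------------------------------------------------------------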