# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import ModuleList
from torch import Tensor

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import (MLP, ConditionalAttention, coordinate_to_encoding,
                    inverse_sigmoid)


class DABDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in DAB-DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN, normalization and
        others."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)
        self.keep_query_pos = self.cross_attn.keep_query_pos

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                ref_sine_embed: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                dim].
            key (Tensor): The key tensor with shape [bs, num_keys, dim].
            query_pos (Tensor): The positional encoding for `query` in self
                attention, with the same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            ref_sine_embed (Tensor): The positional encoding for `query` in
                cross attention, with the same shape as `query`.
                Defaults to None.
            self_attn_masks (Tensor): ByteTensor mask with shape
                [num_queries, num_queries]. Same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_masks (Tensor): ByteTensor mask with shape
                [num_queries, num_keys]. Same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator of whether the current layer is the
                first layer of the decoder. Defaults to False.

        Returns:
            Tensor: Forwarded results with shape [bs, num_queries, dim].
        """
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks,
            **kwargs)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            ref_sine_embed=ref_sine_embed,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            is_first=is_first,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
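

# Illustrative sketch, not part of the original module: how a single
# DABDetrTransformerDecoderLayer might be built and called on dummy tensors.
# The config keys passed to `ConditionalAttention` and `FFN` below
# (`num_heads`, `cross_attn`, `feedforward_channels`, ...) are assumptions
# based on the mmdet/mmcv implementations; check the actual signatures before
# relying on them.
def _example_dab_decoder_layer():
    layer = DABDetrTransformerDecoderLayer(
        self_attn_cfg=dict(embed_dims=256, num_heads=8),
        cross_attn_cfg=dict(embed_dims=256, num_heads=8, cross_attn=True),
        ffn_cfg=dict(embed_dims=256, feedforward_channels=2048),
        norm_cfg=dict(type='LN'))
    bs, num_queries, num_keys, dim = 2, 10, 100, 256
    query = torch.randn(bs, num_queries, dim)  # content queries
    key = torch.randn(bs, num_keys, dim)  # encoder memory
    query_pos = torch.randn(bs, num_queries, dim)  # positional queries
    key_pos = torch.randn(bs, num_keys, dim)  # memory positional encoding
    # sine embedding of the anchor boxes, already sliced to `embed_dims`
    ref_sine_embed = torch.randn(bs, num_queries, dim)
    out = layer(
        query,
        key,
        query_pos=query_pos,
        key_pos=key_pos,
        ref_sine_embed=ref_sine_embed,
        is_first=True)
    return out  # (bs, num_queries, dim)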


class DABDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of DAB-DETR.

    Args:
        query_dim (int): The last dimension of query pos,
            4 for anchor format, 2 for point format.
            Defaults to 4.
        query_scale_type (str): Type of transformation applied
            to the content query. Defaults to `cond_elewise`.
        with_modulated_hw_attn (bool): Whether to inject h&w info
            during cross conditional attention. Defaults to True.
    """

    def __init__(self,
                 *args,
                 query_dim: int = 4,
                 query_scale_type: str = 'cond_elewise',
                 with_modulated_hw_attn: bool = True,
                 **kwargs):
        self.query_dim = query_dim
        self.query_scale_type = query_scale_type
        self.with_modulated_hw_attn = with_modulated_hw_attn

        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initialize decoder layers and other layers."""
        assert self.query_dim in [2, 4], \
            'dab-detr only supports anchor prior or reference point prior'
        assert self.query_scale_type in [
            'cond_elewise', 'cond_scalar', 'fix_elewise'
        ]

        self.layers = ModuleList([
            DABDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])

        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims

        self.post_norm = build_norm_layer(self.post_norm_cfg, embed_dims)[1]
        if self.query_scale_type == 'cond_elewise':
            self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)
        elif self.query_scale_type == 'cond_scalar':
            self.query_scale = MLP(embed_dims, embed_dims, 1, 2)
        elif self.query_scale_type == 'fix_elewise':
            self.query_scale = nn.Embedding(self.num_layers, embed_dims)
        else:
            raise NotImplementedError('Unknown query_scale_type: {}'.format(
                self.query_scale_type))

        self.ref_point_head = MLP(self.query_dim // 2 * embed_dims,
                                  embed_dims, embed_dims, 2)

        if self.with_modulated_hw_attn and self.query_dim == 4:
            self.ref_anchor_head = MLP(embed_dims, embed_dims, 2, 2)

        self.keep_query_pos = self.layers[0].keep_query_pos
        if not self.keep_query_pos:
            for layer_id in range(self.num_layers - 1):
                self.layers[layer_id + 1].cross_attn.qpos_proj = None
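
    # Added commentary (not original code) on the layers built above:
    # * `ref_point_head` consumes the sinusoidal encoding of the anchor
    #   coordinates. With `num_feats = embed_dims // 2` per coordinate, the
    #   encoding of `query_dim` coordinates has size
    #   `query_dim // 2 * embed_dims`, which is mapped back to `embed_dims`
    #   to form the positional query of each layer.
    # * `query_scale` produces the per-layer transformation applied to that
    #   sinusoidal encoding: an element-wise vector ('cond_elewise'), a
    #   single scalar ('cond_scalar'), or a learned per-layer embedding
    #   ('fix_elewise').
    # * When `keep_query_pos` is False, the query positional projection of
    #   cross attention is dropped for every layer except the first, as in
    #   Conditional DETR.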

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                reg_branches: nn.Module,
                key_padding_mask: Tensor = None,
                **kwargs) -> List[Tensor]:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim).
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            reg_branches (nn.Module): The regression branch for dynamically
                updating references in each layer.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            List[Tensor]: Forwarded results. The first tensor has shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is True, otherwise (1, bs, num_queries,
            dim). The second tensor contains the references with shape
            (num_decoder_layers, bs, num_queries, 2/4).
        """
        output = query
        unsigmoid_references = query_pos

        reference_points = unsigmoid_references.sigmoid()
        intermediate_reference_points = [reference_points]

        intermediate = []
        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[..., :self.query_dim]
            ref_sine_embed = coordinate_to_encoding(
                coord_tensor=obj_center, num_feats=self.embed_dims // 2)
            query_pos = self.ref_point_head(
                ref_sine_embed)  # [bs, nq, 2c] -> [bs, nq, c]
            # For the first decoder layer, do not apply transformation
            if self.query_scale_type != 'fix_elewise':
                if layer_id == 0:
                    pos_transformation = 1
                else:
                    pos_transformation = self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]
            # apply transformation
            ref_sine_embed = ref_sine_embed[
                ..., :self.embed_dims] * pos_transformation
            # modulated height and width attention
            if self.with_modulated_hw_attn:
                assert obj_center.size(-1) == 4
                ref_hw = self.ref_anchor_head(output).sigmoid()
                ref_sine_embed[..., self.embed_dims // 2:] *= \
                    (ref_hw[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                ref_sine_embed[..., :self.embed_dims // 2] *= \
                    (ref_hw[..., 1] / obj_center[..., 3]).unsqueeze(-1)

            output = layer(
                output,
                key,
                query_pos=query_pos,
                ref_sine_embed=ref_sine_embed,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                is_first=(layer_id == 0),
                **kwargs)
            # iterative anchor update
            tmp_reg_preds = reg_branches(output)
            tmp_reg_preds[..., :self.query_dim] += inverse_sigmoid(
                reference_points)
            new_reference_points = tmp_reg_preds[
                ..., :self.query_dim].sigmoid()
            if layer_id != self.num_layers - 1:
                intermediate_reference_points.append(new_reference_points)
            reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.post_norm(output))

        output = self.post_norm(output)

        if self.return_intermediate:
            return [
                torch.stack(intermediate),
                torch.stack(intermediate_reference_points),
            ]
        else:
            return [
                output.unsqueeze(0),
                torch.stack(intermediate_reference_points)
            ]
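

# Minimal sketch, not part of the original module, of the per-layer anchor
# refinement performed in `DABDetrTransformerDecoder.forward`: the regression
# output is added to the anchors in unsigmoided (logit) space and squashed
# back into [0, 1]. `reg_delta` is a hypothetical stand-in for
# `reg_branches(output)[..., :query_dim]`.
def _example_anchor_refinement():
    bs, num_queries, query_dim = 2, 5, 4
    # normalized (cx, cy, w, h) anchors in [0, 1]
    reference_points = torch.rand(bs, num_queries, query_dim)
    # predicted offsets in logit space
    reg_delta = 0.1 * torch.randn(bs, num_queries, query_dim)
    new_reference_points = (reg_delta +
                            inverse_sigmoid(reference_points)).sigmoid()
    # The refined anchors are detached before the next layer, mirroring the
    # `.detach()` in the decoder loop, so gradients do not flow through the
    # reference points across layers.
    return new_reference_points.detach()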


class DABDetrTransformerEncoder(DetrTransformerEncoder):
    """Encoder of DAB-DETR."""

    def _init_layers(self):
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims
        self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs):
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries, has
                shape (bs, num_feat_points, dim).
            key_padding_mask (Tensor): ByteTensor, the key padding mask
                of the queries, has shape (bs, num_feat_points).

        Returns:
            Tensor: Forwarded results with shape (bs, num_queries, dim).
        """
        for layer in self.layers:
            pos_scales = self.query_scale(query)
            query = layer(
                query,
                query_pos=query_pos * pos_scales,
                key_padding_mask=key_padding_mask,
                **kwargs)

        return query
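

# Minimal sketch, not part of the original module, of the conditional spatial
# scaling used in `DABDetrTransformerEncoder.forward`: an MLP predicts an
# element-wise scale from the current content features, and the positional
# embedding is rescaled before every encoder layer. Shapes follow the
# batch-first convention used throughout this file.
def _example_encoder_pos_scaling():
    embed_dims, bs, num_feat_points = 256, 2, 300
    query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)
    query = torch.randn(bs, num_feat_points, embed_dims)  # content features
    query_pos = torch.randn(bs, num_feat_points, embed_dims)  # sine embedding
    scaled_pos = query_pos * query_scale(query)  # per-element rescaling
    return scaled_pos  # what each encoder layer receives as `query_pos`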