# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from torch import Tensor
from torch.nn import ModuleList

from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
from .utils import MLP, ConditionalAttention, coordinate_to_encoding


class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Conditional DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers and other layers."""
        self.layers = ModuleList([
            ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
        # Conditional DETR: `query_scale` predicts, from the query content,
        # a transformation applied to the sine embedding of the reference
        # points; `ref_point_head` predicts 2-d reference points from
        # `query_pos`.
        self.query_scale = MLP(self.embed_dims, self.embed_dims,
                               self.embed_dims, 2)
        self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
        # 'qpos_proj' has been substituted with 'qpos_sine_proj' in every
        # decoder layer except the first one, so 'qpos_proj' is deleted in
        # all layers after the first.
        for layer_id in range(self.num_layers - 1):
            self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                key_padding_mask: Tensor = None):
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape
                (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim). If
                `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor): The positional encoding for `query`, with
                the same shape as `query`. If not `None`, it will be added
                to `query` before the forward function. Defaults to `None`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not `None`, it will be added to
                `key` before the forward function. If `None`, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used as `key_pos`. Defaults to `None`.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            tuple[Tensor, Tensor]: The first element is the output queries,
            with shape (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True`, otherwise (1, bs, num_queries,
            dim). The second element is the reference points, with shape
            (bs, num_queries, 2).
        """
        reference_unsigmoid = self.ref_point_head(
            query_pos)  # [bs, num_queries, 2]
        reference = reference_unsigmoid.sigmoid()
        reference_xy = reference[..., :2]
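        # `reference_xy` holds normalized (x, y) reference coordinates in
        # [0, 1]; each decoder layer turns them into a sine embedding below.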
        intermediate = []
        for layer_id, layer in enumerate(self.layers):
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(query)
            # get sine embedding for the query reference
            ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
            # apply transformation
            ref_sine_embed = ref_sine_embed * pos_transformation
            query = layer(
                query,
                key=key,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                ref_sine_embed=ref_sine_embed,
                is_first=(layer_id == 0))
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))

        if self.return_intermediate:
            return torch.stack(intermediate), reference

        query = self.post_norm(query)
        return query.unsqueeze(0), reference


class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Conditional DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN, and
        normalization."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

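    # NOTE: norms[0], norms[1] and norms[2] are applied after the
    # self-attention, the cross-attention and the FFN respectively
    # (post-norm ordering), as done in `forward` below.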
    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                ref_sine_embed: Tensor = None,
                is_first: bool = False):
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries,
                dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to None.
            query_pos (Tensor, optional): The positional encoding for
                `query`, has the same shape as `query`. If not None, it
                will be added to `query` before the forward function.
                Defaults to None.
            key_pos (Tensor, optional): The positional encoding for `key`,
                has the same shape as `key`. If not None, it will be added
                to `key` before the forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            self_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): ByteTensor, has shape
                (bs, num_keys). Defaults to None.
            ref_sine_embed (Tensor): The positional encoding for `query` in
                cross attention, has the same shape as `query`. Defaults to
                None.
            is_first (bool): An indicator to tell whether the current layer
                is the first layer of the decoder. Defaults to False.

        Returns:
            Tensor: Forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            ref_sine_embed=ref_sine_embed,
            is_first=is_first)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
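

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module, showing
    # the expected tensor shapes. The config keys below (`self_attn_cfg`,
    # `cross_attn_cfg`, `ffn_cfg`, `post_norm_cfg`, ...) are assumptions
    # based on how `DetrTransformerDecoder` and `ConditionalAttention` are
    # typically configured in MMDetection and may differ between versions.
    # Note that the relative imports above mean this module is normally
    # imported as part of the package rather than executed directly.
    decoder = ConditionalDetrTransformerDecoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8),
            cross_attn_cfg=dict(embed_dims=256, num_heads=8,
                                cross_attn=True),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)),
        post_norm_cfg=dict(type='LN'),
        return_intermediate=True)
    bs, num_queries, num_keys, dim = 2, 100, 300, 256
    query = torch.zeros(bs, num_queries, dim)
    key = torch.rand(bs, num_keys, dim)
    query_pos = torch.rand(bs, num_queries, dim)
    key_pos = torch.rand(bs, num_keys, dim)
    outputs, references = decoder(
        query=query,
        key=key,
        query_pos=query_pos,
        key_pos=key_pos)
    # Expected: outputs (num_layers, bs, num_queries, dim) and
    # references (bs, num_queries, 2).
    print(outputs.shape, references.shape)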