# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.drop import Dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor:
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()
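

# A minimal round-trip sketch for the two helpers above (the tensor sizes
# here are illustrative assumptions, not values used elsewhere in this file):
#
#   x_nchw = torch.rand(2, 256, 4, 5)       # [N, C, H, W]
#   x_nlc = nchw_to_nlc(x_nchw)             # [2, 20, 256]
#   x_back = nlc_to_nchw(x_nlc, (4, 5))     # [2, 256, 4, 5]
#   assert torch.equal(x_nchw, x_back)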


def coordinate_to_encoding(coord_tensor: Tensor,
                           num_feats: int = 128,
                           temperature: int = 10000,
                           scale: float = 2 * math.pi):
    """Convert coordinate tensor to positional encoding.

    Args:
        coord_tensor (Tensor): Coordinate tensor to be converted to
            positional encoding, with the last dimension as 2 or 4.
        num_feats (int, optional): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value. Defaults to 128.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        scale (float, optional): A scale factor that scales the position
            embedding. Defaults to 2*pi.

    Returns:
        Tensor: Returned encoded positional tensor.
    """
    dim_t = torch.arange(
        num_feats, dtype=torch.float32, device=coord_tensor.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_feats)
    x_embed = coord_tensor[..., 0] * scale
    y_embed = coord_tensor[..., 1] * scale
    pos_x = x_embed[..., None] / dim_t
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()),
                        dim=-1).flatten(2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()),
                        dim=-1).flatten(2)
    if coord_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=-1)
    elif coord_tensor.size(-1) == 4:
        w_embed = coord_tensor[..., 2] * scale
        pos_w = w_embed[..., None] / dim_t
        pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()),
                            dim=-1).flatten(2)
        h_embed = coord_tensor[..., 3] * scale
        pos_h = h_embed[..., None] / dim_t
        pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()),
                            dim=-1).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            coord_tensor.size(-1)))
    return pos
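

# Usage sketch for coordinate_to_encoding (shapes are assumptions): a batch
# of normalized (cx, cy) points of shape [bs, num_queries, 2] is mapped to
# sine/cosine embeddings of shape [bs, num_queries, 2 * num_feats]; a 4-dim
# (cx, cy, w, h) input would give 4 * num_feats channels instead.
#
#   coords = torch.rand(2, 300, 2)
#   pos = coordinate_to_encoding(coords, num_feats=128)   # [2, 300, 256]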


def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to do the inverse.
        eps (float): EPS used to avoid numerical overflow. Defaults to 1e-5.

    Returns:
        Tensor: The tensor after applying the inverse function of sigmoid,
        with the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
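

# Sanity-check sketch for inverse_sigmoid (the probabilities are arbitrary
# assumptions): away from the clamped boundaries it undoes torch.sigmoid.
#
#   p = torch.tensor([0.1, 0.5, 0.9])
#   logits = inverse_sigmoid(p)                  # ~[-2.197, 0.000, 2.197]
#   assert torch.allclose(torch.sigmoid(logits), p, atol=1e-4)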


class AdaptivePadding(nn.Module):
    """Applies padding to the input (if needed) so that the input can be
    fully covered by the specified filter. It supports two modes, "same" and
    "corner". The "same" mode is the same as the "SAME" padding mode in
    TensorFlow and pads zeros around the input. The "corner" mode pads zeros
    to the bottom right of the input.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements. Default: 1.
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super(AdaptivePadding, self).__init__()

        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x
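

# Complementary sketch to the docstring example above (input size is an
# arbitrary assumption): with padding='same' the extra rows/columns are split
# between both sides instead of all being appended at the bottom right, but
# the padded output size is the same.
#
#   same_pad = AdaptivePadding(kernel_size=16, stride=16, padding='same')
#   out = same_pad(torch.rand(1, 1, 15, 17))
#   assert (out.shape[2], out.shape[3]) == (16, 32)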


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of convolution
            to generate patch embedding. Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The slide stride of embedding conv.
            Default: 16. If set to None, it will be set to ``kernel_size``.
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only works when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmengine.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 conv_type: str = 'Conv2d',
                 kernel_size: int = 16,
                 stride: int = 16,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: int = 1,
                 bias: bool = True,
                 norm_cfg: OptConfigType = None,
                 input_size: Union[int, tuple] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
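

# Usage sketch for PatchEmbed (image size and embed_dims are assumptions):
# a 224x224 image split into 16x16 patches yields 14 * 14 = 196 tokens.
#
#   patch_embed = PatchEmbed(
#       in_channels=3, embed_dims=768, kernel_size=16, stride=16)
#   tokens, out_size = patch_embed(torch.rand(1, 3, 224, 224))
#   # tokens: [1, 196, 768], out_size: (14, 14)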


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the original
    implementation. However, we need to modify pretrained models for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Optional[Union[int, tuple]] = 2,
                 stride: Optional[Union[int, tuple]] = None,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: Optional[Union[int, tuple]] = 1,
                 bias: Optional[bool] = False,
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x: Tensor,
                input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as
                (H, W). Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
            f'input_size is `Sequence` ' \
            f'but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size
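

# Usage sketch for PatchMerging (Swin-style 2x downsampling; the channel
# counts and spatial size are assumptions): a 56x56 map with 96 channels is
# merged into a 28x28 map with 192 channels.
#
#   merge = PatchMerging(in_channels=96, out_channels=192)
#   x = torch.rand(1, 56 * 56, 96)
#   y, out_size = merge(x, (56, 56))
#   # y: [1, 784, 192], out_size: (28, 28)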


class ConditionalAttention(BaseModule):
    """A wrapper of conditional attention, dropout and residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        cross_attn (bool): Whether the attention module is for cross
            attention. Default: False.
        keep_query_pos (bool): Whether to transform query_pos before cross
            attention. Default: False.
        batch_first (bool): When it is True, Key, Query and Value are shape
            of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default: True.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 cross_attn: bool = False,
                 keep_query_pos: bool = False,
                 batch_first: bool = True,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        assert batch_first is True, 'Setting `batch_first` to False is ' \
            'NOT supported in ConditionalAttention. The first dimension ' \
            'of all DETRs in mmdet is `batch`, please set `batch_first` ' \
            'to True.'

        self.cross_attn = cross_attn
        self.keep_query_pos = keep_query_pos
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn_drop = Dropout(attn_drop)
        self.proj_drop = Dropout(proj_drop)

        self._init_layers()

    def _init_layers(self):
        """Initialize layers for qkv projection."""
        embed_dims = self.embed_dims
        self.qcontent_proj = Linear(embed_dims, embed_dims)
        self.qpos_proj = Linear(embed_dims, embed_dims)
        self.kcontent_proj = Linear(embed_dims, embed_dims)
        self.kpos_proj = Linear(embed_dims, embed_dims)
        self.v_proj = Linear(embed_dims, embed_dims)
        if self.cross_attn:
            self.qpos_sine_proj = Linear(embed_dims, embed_dims)
        self.out_proj = Linear(embed_dims, embed_dims)

        nn.init.constant_(self.out_proj.bias, 0.)
    def forward_attn(self,
                     query: Tensor,
                     key: Tensor,
                     value: Tensor,
                     attn_mask: Tensor = None,
                     key_padding_mask: Tensor = None) -> Tuple[Tensor]:
        """Forward process for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims]. If None, the `query` will be used.
                Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. If None, the
                `key` will be used. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tuple[Tensor]: Attention outputs of shape :math:`(N, L, E)`,
            where :math:`N` is the batch size, :math:`L` is the target
            sequence length and :math:`E` is the embedding dimension
            `embed_dim`, and attention weights averaged over heads of
            shape :math:`(N, L, S)`, where :math:`S` is the source
            sequence length.
        """
        assert key.size(1) == value.size(1), \
            'key, value must have the same sequence length'
        assert query.size(0) == key.size(0) == value.size(0), \
            'batch size must be equal for query, key, value'
        assert query.size(2) == key.size(2), \
            'q_dims, k_dims must be equal'
        assert value.size(2) == self.embed_dims, \
            'v_dims must be equal to embed_dims'

        bs, tgt_len, hidden_dims = query.size()
        _, src_len, _ = key.size()
        head_dims = hidden_dims // self.num_heads
        v_head_dims = self.embed_dims // self.num_heads
        assert head_dims * self.num_heads == hidden_dims, \
            'hidden_dims must be divisible by num_heads'
        scaling = float(head_dims)**-0.5

        q = query * scaling
        k = key
        v = value

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or \
                attn_mask.dtype == torch.float64 or \
                attn_mask.dtype == torch.float16 or \
                attn_mask.dtype == torch.uint8 or \
                attn_mask.dtype == torch.bool, \
                'Only float, byte, and bool types are supported for ' \
                'attn_mask'
            if attn_mask.dtype == torch.uint8:
                warnings.warn('Byte tensor for attn_mask is deprecated. '
                              'Use bool tensor instead.')
                attn_mask = attn_mask.to(torch.bool)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(1), key.size(1)]:
                    raise RuntimeError(
                        'The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                        bs * self.num_heads,
                        query.size(1),
                        key.size(1)
                ]:
                    raise RuntimeError(
                        'The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()))
            # attn_mask's dim is 3 now.

        if key_padding_mask is not None and key_padding_mask.dtype == int:
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = q.contiguous().view(bs, tgt_len, self.num_heads,
                                head_dims).permute(0, 2, 1, 3).flatten(0, 1)
        if k is not None:
            k = k.contiguous().view(bs, src_len, self.num_heads,
                                    head_dims).permute(0, 2, 1,
                                                       3).flatten(0, 1)
        if v is not None:
            v = v.contiguous().view(bs, src_len, self.num_heads,
                                    v_head_dims).permute(0, 2, 1,
                                                         3).flatten(0, 1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bs
            assert key_padding_mask.size(1) == src_len

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bs * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bs, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bs * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights -
            attn_output_weights.max(dim=-1, keepdim=True)[0],
            dim=-1)
        attn_output_weights = self.attn_drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(
            attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims]
        attn_output = attn_output.view(bs, self.num_heads, tgt_len,
                                       v_head_dims).permute(0, 2, 1,
                                                            3).flatten(2)
        attn_output = self.out_proj(attn_output)

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bs, self.num_heads,
                                                       tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / self.num_heads
    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor = None,
                ref_sine_embed: Tensor = None,
                key_pos: Tensor = None,
                attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False) -> Tensor:
        """Forward function for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims]. If None, the `query` will be used.
                Defaults to None.
            query_pos (Tensor): The positional encoding for query in self
                attention, with the same shape as `x`. If not None, it will
                be added to `x` before forward function.
                Defaults to None.
            ref_sine_embed (Tensor): The positional encoding for query in
                cross attention, with the same shape as `x`. If not None, it
                will be added to `x` before forward function.
                Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before forward function. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator to tell whether the current layer
                is the first layer of the decoder.
                Defaults to False.

        Returns:
            Tensor: forwarded results with shape
            [bs, num_queries, embed_dims].
        """
        if self.cross_attn:
            q_content = self.qcontent_proj(query)
            k_content = self.kcontent_proj(key)
            v = self.v_proj(key)

            bs, nq, c = q_content.size()
            _, hw, _ = k_content.size()

            k_pos = self.kpos_proj(key_pos)
            if is_first or self.keep_query_pos:
                q_pos = self.qpos_proj(query_pos)
                q = q_content + q_pos
                k = k_content + k_pos
            else:
                q = q_content
                k = k_content
            q = q.view(bs, nq, self.num_heads, c // self.num_heads)
            query_sine_embed = self.qpos_sine_proj(ref_sine_embed)
            query_sine_embed = query_sine_embed.view(bs, nq, self.num_heads,
                                                     c // self.num_heads)
            q = torch.cat([q, query_sine_embed], dim=3).view(bs, nq, 2 * c)
            k = k.view(bs, hw, self.num_heads, c // self.num_heads)
            k_pos = k_pos.view(bs, hw, self.num_heads, c // self.num_heads)
            k = torch.cat([k, k_pos], dim=3).view(bs, hw, 2 * c)
            ca_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(ca_output)
        else:
            q_content = self.qcontent_proj(query)
            q_pos = self.qpos_proj(query_pos)
            k_content = self.kcontent_proj(query)
            k_pos = self.kpos_proj(query_pos)
            v = self.v_proj(query)
            q = q_content if q_pos is None else q_content + q_pos
            k = k_content if k_pos is None else k_content + k_pos
            sa_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(sa_output)

        return query
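

# Usage sketch for ConditionalAttention (all sizes are assumptions): a
# self-attention layer over 300 queries with 256-dim embeddings and 8 heads.
# For cross attention, construct with cross_attn=True and additionally pass
# `key`, `key_pos` and `ref_sine_embed` (the encoder memory, its positional
# encoding, and the reference-point sine embedding).
#
#   self_attn = ConditionalAttention(embed_dims=256, num_heads=8)
#   query = torch.rand(2, 300, 256)
#   query_pos = torch.rand(2, 300, 256)
#   out = self_attn(query=query, key=query, query_pos=query_pos)
#   # out: [2, 300, 256]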


class MLP(BaseModule):
    """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
    used in DETR series detectors.

    Args:
        input_dim (int): Feature dim of the input tensor.
        hidden_dim (int): Feature dim of the hidden layer.
        output_dim (int): Feature dim of the output tensor.
        num_layers (int): Number of FFN layers. The last layer is a plain
            Linear without activation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = ModuleList(
            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of MLP.

        Args:
            x (Tensor): The input feature, has shape
                (num_queries, bs, input_dim).

        Returns:
            Tensor: The output feature, has shape
            (num_queries, bs, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
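

# Usage sketch for MLP (dims are assumptions): the 3-layer head used by
# DETR-style detectors to regress 4-d boxes from 256-d query embeddings.
#
#   bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   boxes = bbox_head(torch.rand(300, 2, 256)).sigmoid()   # [300, 2, 4]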


@MODELS.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1*1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default.
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project two-dimensional feature to
            one-dimensional feature. Defaults to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 256,
                 feat_channels: int = 64,
                 out_channels: Optional[int] = None,
                 input_feat_shape: int = 7,
                 with_proj: bool = True,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super(DynamicConv, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature can be used
                to generate the parameter, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interact with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
            (num_all_proposals, out_channels).
        """
        input_feature = input_feature.flatten(2).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # feature has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features
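

# Usage sketch for DynamicConv (the proposal count is an assumption): each of
# the 100 proposal features generates its own 1x1-conv parameters, which then
# interact with the corresponding 7x7 RoI feature.
#
#   dyn_conv = DynamicConv(in_channels=256, feat_channels=64,
#                          input_feat_shape=7)
#   param_feat = torch.rand(100, 256)        # (num_proposals, in_channels)
#   roi_feat = torch.rand(100, 256, 7, 7)    # (num_proposals, C, H, W)
#   out = dyn_conv(param_feat, roi_feat)     # (100, 256) since with_proj=True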