# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.drop import Dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor:
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x: Tensor) -> Tensor:
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()


def coordinate_to_encoding(coord_tensor: Tensor,
                           num_feats: int = 128,
                           temperature: int = 10000,
                           scale: float = 2 * math.pi) -> Tensor:
    """Convert coordinate tensor to positional encoding.

    Args:
        coord_tensor (Tensor): Coordinate tensor to be converted to
            positional encoding. With the last dimension as 2 (x, y) or
            4 (x, y, w, h). Any number of leading dimensions is accepted.
        num_feats (int, optional): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value (4 times for
            4-d coordinates). Defaults to 128.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        scale (float, optional): A scale factor that every coordinate is
            multiplied by before encoding. Defaults to 2*pi.

    Returns:
        Tensor: Returned encoded positional tensor. It has the leading
        dimensions of ``coord_tensor`` and a last dimension of
        ``coord_tensor.size(-1) * num_feats``, ordered as (y, x[, w, h]).

    Raises:
        ValueError: If the last dimension of ``coord_tensor`` is neither
            2 nor 4.
    """
    dim_t = torch.arange(
        num_feats, dtype=torch.float32, device=coord_tensor.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_feats)

    def _encode(index: int) -> Tensor:
        """Sin/cos-encode one coordinate component."""
        pos = (coord_tensor[..., index] * scale)[..., None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        # `flatten(-2)` merges only the two trailing dims, so inputs with
        # any number of leading dims work (same as `flatten(2)` for 3-D).
        return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()),
                           dim=-1).flatten(-2)

    pos_x = _encode(0)
    pos_y = _encode(1)
    if coord_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=-1)
    elif coord_tensor.size(-1) == 4:
        pos_w = _encode(2)
        pos_h = _encode(3)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            coord_tensor.size(-1)))
    return pos


def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to do the inverse. Values are clamped to
            [0, 1] first.
        eps (float): Lower bound applied to both ``x`` and ``1 - x`` to
            avoid numerical overflow. Defaults to 1e-5.

    Returns:
        Tensor: The x has passed the inverse function of sigmoid, has the
        same shape with input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


class AdaptivePadding(nn.Module):
    """Applies padding to input (if needed) so that input can get fully
    covered by filter you specified. It supports two modes "same" and
    "corner". The "same" mode is same with "SAME" padding mode in
    TensorFlow, pad zero around input. The "corner" mode would pad zero to
    bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super().__init__()
        assert padding in ('same', 'corner')

        # `padding` is a mode name, not a size: keep it as the plain string
        # that `forward` compares against.
        self.padding = padding
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.dilation = to_2tuple(dilation)

    def get_pad_shape(self, input_shape):
        """Return the total (pad_h, pad_w) needed for ``input_shape``.

        Args:
            input_shape (Sequence[int]): Spatial size (H, W) of the input.

        Returns:
            tuple[int]: Non-negative padding amounts along height and width.
        """
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        """Zero-pad the last two dims of ``x`` according to the mode."""
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_type (str): The config dict for embedding
            conv layer type selection. Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int, optional): The slide stride of embedding conv.
            Default: 16. If None is passed, it is set to `kernel_size`.
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to precompute `init_out_size`. Default: None.
        init_cfg (`mmengine.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 conv_type: str = 'Conv2d',
                 kernel_size: int = 16,
                 stride: Optional[int] = 16,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: int = 1,
                 bias: bool = True,
                 norm_cfg: OptConfigType = None,
                 input_size: Union[int, tuple] = None,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        # Normalised in both branches: `padding[0]`/`padding[1]` are
        # indexed below when `input_size` is given.
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups feature map by kernel_size, and applies norm and linear
    layers to the grouped feature map. Our implementation uses `nn.Unfold` to
    merge patch, which is about 25% faster than original implementation.
    Instead, we need to modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string): The padding length of the unfold
            layer. When it is a string, it means the mode of adaptive
            padding, support "same" and "corner" now, so that the input
            gets fully covered by the kernel and stride you specified.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Optional[Union[int, tuple]] = 2,
                 stride: Optional[Union[int, tuple]] = None,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: Optional[Union[int, tuple]] = 1,
                 bias: Optional[bool] = False,
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # A falsy stride (None / 0) falls back to the kernel size.
        stride = stride or kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x: Tensor,
                input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as
                (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size is `Sequence` but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        # Explicit None check instead of relying on module truthiness.
        if self.norm is not None:
            x = self.norm(x)
        x = self.reduction(x)
        return x, output_size


class ConditionalAttention(BaseModule):
    """A wrapper of conditional attention, dropout and residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer applied to the attention output
            before the residual connection. Default: 0.0.
        cross_attn (bool): Whether the attention module is for cross
            attention. Default: False
        keep_query_pos (bool): Whether to transform query_pos before cross
            attention. Default: False.
        batch_first (bool): When it is True, Key, Query and Value are shape
            of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Only True is supported. Default: True.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 cross_attn: bool = False,
                 keep_query_pos: bool = False,
                 batch_first: bool = True,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        assert batch_first is True, (
            'Set `batch_first` to False is NOT supported in '
            'ConditionalAttention. First dimension of all DETRs in mmdet '
            'is `batch`, please set `batch_first` to True.')
        self.cross_attn = cross_attn
        self.keep_query_pos = keep_query_pos
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn_drop = Dropout(attn_drop)
        self.proj_drop = Dropout(proj_drop)

        self._init_layers()

    def _init_layers(self):
        """Initialize layers for qkv projection."""
        embed_dims = self.embed_dims
        self.qcontent_proj = Linear(embed_dims, embed_dims)
        self.qpos_proj = Linear(embed_dims, embed_dims)
        self.kcontent_proj = Linear(embed_dims, embed_dims)
        self.kpos_proj = Linear(embed_dims, embed_dims)
        self.v_proj = Linear(embed_dims, embed_dims)
        if self.cross_attn:
            self.qpos_sine_proj = Linear(embed_dims, embed_dims)
        self.out_proj = Linear(embed_dims, embed_dims)

        nn.init.constant_(self.out_proj.bias, 0.)

    def forward_attn(self,
                     query: Tensor,
                     key: Tensor,
                     value: Tensor,
                     attn_mask: Tensor = None,
                     key_padding_mask: Tensor = None
                     ) -> Tuple[Tensor, Tensor]:
        """Forward process for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                hidden_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                hidden_dims].
            value (Tensor): The value tensor with shape [bs, num_keys,
                embed_dims].
            attn_mask (Tensor): Float, byte or bool mask with shape
                [num_queries, num_keys] or
                [bs * num_heads, num_queries, num_keys]. Bool/byte masks
                mark positions to be ignored; float masks are added to the
                attention logits. Defaults to None.
            key_padding_mask (Tensor): Mask with shape [bs, num_keys] whose
                non-zero / True entries mark keys to be ignored. Non-bool
                masks are converted to bool. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor]: Attention outputs of shape
            :math:`(N, L, E)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`E` is the
            embedding dimension `embed_dims`; and attention weights
            averaged over heads of shape :math:`(N, L, S)`, where
            :math:`S` is the source sequence length.
        """
        assert key.size(1) == value.size(1), \
            'key, value must have the same sequence length'
        assert query.size(0) == key.size(0) == value.size(0), \
            'batch size must be equal for query, key, value'
        assert query.size(2) == key.size(2), \
            'q_dims, k_dims must be equal'
        assert value.size(2) == self.embed_dims, \
            'v_dims must be equal to embed_dims'

        bs, tgt_len, hidden_dims = query.size()
        _, src_len, _ = key.size()
        head_dims = hidden_dims // self.num_heads
        v_head_dims = self.embed_dims // self.num_heads
        assert head_dims * self.num_heads == hidden_dims, \
            'hidden_dims must be divisible by num_heads'
        scaling = float(head_dims)**-0.5

        q = query * scaling
        k = key
        v = value

        if attn_mask is not None:
            assert attn_mask.dtype in (torch.float32, torch.float64,
                                       torch.float16, torch.uint8,
                                       torch.bool), \
                'Only float, byte, and bool types are supported for attn_mask'
            if attn_mask.dtype == torch.uint8:
                warnings.warn('Byte tensor for attn_mask is deprecated. '
                              'Use bool tensor instead.')
                attn_mask = attn_mask.to(torch.bool)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(1), key.size(1)]:
                    raise RuntimeError(
                        'The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                        bs * self.num_heads,
                        query.size(1),
                        key.size(1)
                ]:
                    raise RuntimeError(
                        'The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()))
        # attn_mask's dim is 3 now.

        # `masked_fill` below needs a bool mask. A torch dtype never
        # compares equal to the Python `int` type, so test against
        # `torch.bool` and convert every other dtype.
        if (key_padding_mask is not None
                and key_padding_mask.dtype != torch.bool):
            key_padding_mask = key_padding_mask.to(torch.bool)

        # Split heads: [bs, len, dims] -> [bs * num_heads, len, head_dims].
        q = q.contiguous().view(bs, tgt_len, self.num_heads,
                                head_dims).permute(0, 2, 1, 3).flatten(0, 1)
        k = k.contiguous().view(bs, src_len, self.num_heads,
                                head_dims).permute(0, 2, 1, 3).flatten(0, 1)
        v = v.contiguous().view(bs, src_len, self.num_heads,
                                v_head_dims).permute(0, 2, 1, 3).flatten(0, 1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bs
            assert key_padding_mask.size(1) == src_len

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bs * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bs, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bs * self.num_heads, tgt_len, src_len)

        # Subtract the row-wise max before softmax.
        attn_output_weights = F.softmax(
            attn_output_weights -
            attn_output_weights.max(dim=-1, keepdim=True)[0],
            dim=-1)
        attn_output_weights = self.attn_drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(
            attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims]
        # Merge heads back: [bs * num_heads, L, d] -> [bs, L, embed_dims].
        attn_output = attn_output.view(bs, self.num_heads, tgt_len,
                                       v_head_dims).permute(0, 2, 1,
                                                            3).flatten(2)
        attn_output = self.out_proj(attn_output)

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bs, self.num_heads,
                                                       tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / self.num_heads

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor = None,
                ref_sine_embed: Tensor = None,
                key_pos: Tensor = None,
                attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False) -> Tensor:
        """Forward function for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims]. Only used for cross attention; self attention
                uses `query` as key and value.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. In self attention it is
                projected and added to both the projected query and key;
                if None, no positional term is added. In cross attention
                it is required when `is_first` or `keep_query_pos` is
                True. Defaults to None.
            ref_sine_embed (Tensor): The sine positional embedding of the
                query used in cross attention, with the same shape as
                `query`. It is projected and concatenated to the query
                per head. Required for cross attention. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Required for cross attention.
                Defaults to None.
            attn_mask (Tensor): Mask with shape [num_queries, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor): Mask with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator to tell whether the current layer
                is the first layer of the decoder. Defaults to False.

        Returns:
            Tensor: forwarded results with shape
            [bs, num_queries, embed_dims].
        """
        if self.cross_attn:
            q_content = self.qcontent_proj(query)
            k_content = self.kcontent_proj(key)
            v = self.v_proj(key)

            bs, nq, c = q_content.size()
            _, hw, _ = k_content.size()

            k_pos = self.kpos_proj(key_pos)
            if is_first or self.keep_query_pos:
                q_pos = self.qpos_proj(query_pos)
                q = q_content + q_pos
                k = k_content + k_pos
            else:
                q = q_content
                k = k_content
            # Concatenate content and positional parts per head, so every
            # head sees [content_head, position_head] (2 * c channels).
            head_dims = c // self.num_heads
            q = q.view(bs, nq, self.num_heads, head_dims)
            query_sine_embed = self.qpos_sine_proj(ref_sine_embed)
            query_sine_embed = query_sine_embed.view(bs, nq, self.num_heads,
                                                     head_dims)
            q = torch.cat([q, query_sine_embed], dim=3).view(bs, nq, 2 * c)
            k = k.view(bs, hw, self.num_heads, head_dims)
            k_pos = k_pos.view(bs, hw, self.num_heads, head_dims)
            k = torch.cat([k, k_pos], dim=3).view(bs, hw, 2 * c)
            ca_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(ca_output)
        else:
            q = self.qcontent_proj(query)
            k = self.kcontent_proj(query)
            v = self.v_proj(query)
            # `query_pos` defaults to None: only project and add it when
            # it is actually provided.
            if query_pos is not None:
                q = q + self.qpos_proj(query_pos)
                k = k + self.kpos_proj(query_pos)
            sa_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(sa_output)

        return query


class MLP(BaseModule):
    """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
    used in DETR series detectors.

    Args:
        input_dim (int): Feature dim of the input tensor.
        hidden_dim (int): Feature dim of the hidden layer.
        output_dim (int): Feature dim of the output tensor.
        num_layers (int): Number of Linear layers. ReLU is applied after
            every layer except the last one.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int) -> None:
        super().__init__()

        self.num_layers = num_layers

        h = [hidden_dim] * (num_layers - 1)
        self.layers = ModuleList(
            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of MLP.

        Args:
            x (Tensor): The input feature, has shape
                (num_queries, bs, input_dim).

        Returns:
            Tensor: The output feature, has shape
                (num_queries, bs, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


@MODELS.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generate parameters for each sample and
    use bmm to implement 1*1 convolution. Code is modified
    from the official github repo.

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project two-dimensional feature to
            one-dimensional feature. Default to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 256,
                 feat_channels: int = 64,
                 out_channels: Optional[int] = None,
                 input_feat_shape: int = 7,
                 with_proj: bool = True,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        # One linear layer emits both dynamic weight matrices at once.
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature can be used
                to generate the parameter, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interact with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
            (num_all_proposals, out_channels) when `with_proj` is True,
            otherwise (num_all_proposals, H*W, out_channels).
        """
        # (N, C, H, W) -> (N, H*W, C); a single permute replaces the
        # original permute(2, 0, 1) followed by permute(1, 0, 2).
        input_feature = input_feature.flatten(2).permute(0, 2, 1)

        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # feature has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (num_all_proposals, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features