utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.drop import Dropout
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


def nlc_to_nchw(x: Tensor, hw_shape: Sequence[int]) -> Tensor:
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()
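

# Illustrative round-trip sketch for `nlc_to_nchw` / `nchw_to_nlc` (not part of
# the original file; shapes chosen for demonstration). The two helpers are exact
# inverses whenever the flattened length matches `hw_shape`:
#
#   >>> feats = torch.rand(2, 3, 8, 8)            # [N, C, H, W]
#   >>> tokens = nchw_to_nlc(feats)               # [N, 64, 3], i.e. [N, H*W, C]
#   >>> restored = nlc_to_nchw(tokens, (8, 8))    # back to [N, 3, 8, 8]
#   >>> assert torch.equal(restored, feats)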


def coordinate_to_encoding(coord_tensor: Tensor,
                           num_feats: int = 128,
                           temperature: int = 10000,
                           scale: float = 2 * math.pi):
    """Convert coordinate tensor to positional encoding.

    Args:
        coord_tensor (Tensor): Coordinate tensor to be converted to
            positional encoding, with the last dimension being 2 or 4.
        num_feats (int, optional): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value. Defaults to 128.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        scale (float, optional): A scale factor that scales the position
            embedding. Defaults to 2*pi.

    Returns:
        Tensor: Returned encoded positional tensor.
    """
    dim_t = torch.arange(
        num_feats, dtype=torch.float32, device=coord_tensor.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_feats)
    x_embed = coord_tensor[..., 0] * scale
    y_embed = coord_tensor[..., 1] * scale
    pos_x = x_embed[..., None] / dim_t
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()),
                        dim=-1).flatten(2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()),
                        dim=-1).flatten(2)
    if coord_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=-1)
    elif coord_tensor.size(-1) == 4:
        w_embed = coord_tensor[..., 2] * scale
        pos_w = w_embed[..., None] / dim_t
        pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()),
                            dim=-1).flatten(2)
        h_embed = coord_tensor[..., 3] * scale
        pos_h = h_embed[..., None] / dim_t
        pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()),
                            dim=-1).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            coord_tensor.size(-1)))
    return pos
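

# A minimal shape sketch for `coordinate_to_encoding` (illustrative only; sizes
# assumed). Normalized (x, y) points map to a 2 * num_feats embedding, and
# (x, y, w, h) boxes to a 4 * num_feats embedding:
#
#   >>> points = torch.rand(2, 300, 2)                        # normalized in [0, 1]
#   >>> coordinate_to_encoding(points, num_feats=128).shape   # (2, 300, 256)
#   >>> boxes = torch.rand(2, 300, 4)
#   >>> coordinate_to_encoding(boxes, num_feats=128).shape    # (2, 300, 512)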


def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to do the inverse.
        eps (float): EPS to avoid numerical overflow. Defaults to 1e-5.

    Returns:
        Tensor: The inverse of sigmoid applied to x, with the same shape as
            the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
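

# A quick numerical check of `inverse_sigmoid` (illustrative, not from the
# original source): away from the clamped endpoints it is the exact inverse of
# `torch.sigmoid`.
#
#   >>> p = torch.tensor([0.1, 0.5, 0.9])
#   >>> logits = inverse_sigmoid(p)              # tensor([-2.1972, 0.0000, 2.1972])
#   >>> torch.allclose(logits.sigmoid(), p, atol=1e-4)
#   True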


class AdaptivePadding(nn.Module):
    """Applies padding to the input (if needed) so that the input can get
    fully covered by the filter you specified. It supports two modes, "same"
    and "corner". The "same" mode is the same as the "SAME" padding mode in
    TensorFlow and pads zeros around the input. The "corner" mode pads zeros
    to the bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super(AdaptivePadding, self).__init__()
        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        """Compute the (pad_h, pad_w) needed to fully cover `input_shape`."""
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of convolution layer used for the
            embedding. Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The slide stride of embedding conv.
            Default: 16. If None, it will be set to `kernel_size`.
        padding (int | tuple | str): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only works when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmengine.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 conv_type: str = 'Conv2d',
                 kernel_size: int = 16,
                 stride: int = 16,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: int = 1,
                 bias: bool = True,
                 norm_cfg: OptConfigType = None,
                 input_size: Union[int, tuple] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x: Tensor) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims).
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
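

# Illustrative usage sketch for `PatchEmbed` (not from the original file; the
# ViT-style settings below are assumed): a 224x224 image with 16x16 patches
# yields 14 * 14 = 196 patch tokens.
#
#   >>> patch_embed = PatchEmbed(in_channels=3, embed_dims=768,
#   >>>                          kernel_size=16, stride=16)
#   >>> tokens, out_size = patch_embed(torch.rand(1, 3, 224, 224))
#   >>> tokens.shape, out_size      # (torch.Size([1, 196, 768]), (14, 14))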


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the original
    implementation. However, pretrained models need to be modified for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | str): The padding length of the unfold
            layer. When it is a string, it means the mode of adaptive
            padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Optional[Union[int, tuple]] = 2,
                 stride: Optional[Union[int, tuple]] = None,
                 padding: Union[int, tuple, str] = 'corner',
                 dilation: Optional[Union[int, tuple]] = 1,
                 bias: Optional[bool] = False,
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # fall back to kernel_size when stride is unset
        stride = stride or kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x: Tensor,
                input_size: Tuple[int]) -> Tuple[Tensor, Tuple[int]]:
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out).
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size is `Sequence` but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size
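

# Illustrative usage sketch for `PatchMerging` (not from the original source;
# Swin-like numbers assumed): a 56x56 map with 96 channels is merged 2x2 into
# a 28x28 map with 192 channels.
#
#   >>> merge = PatchMerging(in_channels=96, out_channels=192)
#   >>> x = torch.rand(2, 56 * 56, 96)              # (B, H*W, C_in)
#   >>> out, out_size = merge(x, (56, 56))
#   >>> out.shape, out_size         # (torch.Size([2, 784, 192]), (28, 28))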


class ConditionalAttention(BaseModule):
    """A wrapper of conditional attention, dropout and residual connection.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        cross_attn (bool): Whether the attention module is for cross
            attention. Default: False.
        keep_query_pos (bool): Whether to transform query_pos before cross
            attention. Default: False.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default: True.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 cross_attn: bool = False,
                 keep_query_pos: bool = False,
                 batch_first: bool = True,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        assert batch_first is True, \
            'Setting `batch_first` to False is NOT supported in ' \
            'ConditionalAttention. The first dimension of all DETRs in ' \
            'mmdet is `batch`, please set `batch_first` to True.'

        self.cross_attn = cross_attn
        self.keep_query_pos = keep_query_pos
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn_drop = Dropout(attn_drop)
        self.proj_drop = Dropout(proj_drop)

        self._init_layers()

    def _init_layers(self):
        """Initialize layers for qkv projection."""
        embed_dims = self.embed_dims
        self.qcontent_proj = Linear(embed_dims, embed_dims)
        self.qpos_proj = Linear(embed_dims, embed_dims)
        self.kcontent_proj = Linear(embed_dims, embed_dims)
        self.kpos_proj = Linear(embed_dims, embed_dims)
        self.v_proj = Linear(embed_dims, embed_dims)
        if self.cross_attn:
            self.qpos_sine_proj = Linear(embed_dims, embed_dims)
        self.out_proj = Linear(embed_dims, embed_dims)

        nn.init.constant_(self.out_proj.bias, 0.)
    def forward_attn(self,
                     query: Tensor,
                     key: Tensor,
                     value: Tensor,
                     attn_mask: Tensor = None,
                     key_padding_mask: Tensor = None) -> Tuple[Tensor]:
        """Forward process for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims]. If None, the `query` will be used.
                Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. If None, the `key`
                will be used. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tuple[Tensor]: Attention outputs of shape :math:`(N, L, E)`,
            where :math:`N` is the batch size, :math:`L` is the target
            sequence length, and :math:`E` is the embedding dimension
            `embed_dim`, together with the attention weights averaged over
            heads of shape :math:`(N, L, S)`, where :math:`S` is the source
            sequence length.
        """
        assert key.size(1) == value.size(1), \
            'key, value must have the same sequence length'
        assert query.size(0) == key.size(0) == value.size(0), \
            'batch size must be equal for query, key, value'
        assert query.size(2) == key.size(2), \
            'q_dims, k_dims must be equal'
        assert value.size(2) == self.embed_dims, \
            'v_dims must be equal to embed_dims'

        bs, tgt_len, hidden_dims = query.size()
        _, src_len, _ = key.size()
        head_dims = hidden_dims // self.num_heads
        v_head_dims = self.embed_dims // self.num_heads
        assert head_dims * self.num_heads == hidden_dims, \
            'hidden_dims must be divisible by num_heads'
        scaling = float(head_dims)**-0.5

        q = query * scaling
        k = key
        v = value

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or \
                   attn_mask.dtype == torch.float64 or \
                   attn_mask.dtype == torch.float16 or \
                   attn_mask.dtype == torch.uint8 or \
                   attn_mask.dtype == torch.bool, \
                   'Only float, byte, and bool types are supported for ' \
                   'attn_mask'
            if attn_mask.dtype == torch.uint8:
                warnings.warn('Byte tensor for attn_mask is deprecated. '
                              'Use bool tensor instead.')
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(1), key.size(1)]:
                    raise RuntimeError(
                        'The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                        bs * self.num_heads,
                        query.size(1),
                        key.size(1)
                ]:
                    raise RuntimeError(
                        'The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()))
            # attn_mask's dim is 3 now.

        if key_padding_mask is not None and key_padding_mask.dtype == int:
            key_padding_mask = key_padding_mask.to(torch.bool)

        q = q.contiguous().view(bs, tgt_len, self.num_heads,
                                head_dims).permute(0, 2, 1, 3).flatten(0, 1)
        if k is not None:
            k = k.contiguous().view(bs, src_len, self.num_heads,
                                    head_dims).permute(0, 2, 1,
                                                       3).flatten(0, 1)
        if v is not None:
            v = v.contiguous().view(bs, src_len, self.num_heads,
                                    v_head_dims).permute(0, 2, 1,
                                                         3).flatten(0, 1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bs
            assert key_padding_mask.size(1) == src_len

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bs * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bs, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bs * self.num_heads, tgt_len, src_len)

        attn_output_weights = F.softmax(
            attn_output_weights -
            attn_output_weights.max(dim=-1, keepdim=True)[0],
            dim=-1)
        attn_output_weights = self.attn_drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(
            attn_output.size()) == [bs * self.num_heads, tgt_len, v_head_dims]
        attn_output = attn_output.view(bs, self.num_heads, tgt_len,
                                       v_head_dims).permute(0, 2, 1,
                                                            3).flatten(2)
        attn_output = self.out_proj(attn_output)

        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bs, self.num_heads,
                                                       tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / self.num_heads
    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor = None,
                ref_sine_embed: Tensor = None,
                key_pos: Tensor = None,
                attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False) -> Tensor:
        """Forward function for `ConditionalAttention`.

        Args:
            query (Tensor): The input query with shape [bs, num_queries,
                embed_dims].
            key (Tensor): The key tensor with shape [bs, num_keys,
                embed_dims]. If None, the `query` will be used.
                Defaults to None.
            query_pos (Tensor): The positional encoding for query in self
                attention, with the same shape as `x`. If not None, it will
                be added to `x` before forward function.
                Defaults to None.
            ref_sine_embed (Tensor): The positional encoding for query in
                cross attention, with the same shape as `x`. If not None, it
                will be added to `x` before forward function.
                Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before forward function. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator to tell whether the current layer
                is the first layer of the decoder.
                Defaults to False.

        Returns:
            Tensor: forwarded results with shape
            [bs, num_queries, embed_dims].
        """
        if self.cross_attn:
            q_content = self.qcontent_proj(query)
            k_content = self.kcontent_proj(key)
            v = self.v_proj(key)

            bs, nq, c = q_content.size()
            _, hw, _ = k_content.size()

            k_pos = self.kpos_proj(key_pos)
            if is_first or self.keep_query_pos:
                q_pos = self.qpos_proj(query_pos)
                q = q_content + q_pos
                k = k_content + k_pos
            else:
                q = q_content
                k = k_content
            q = q.view(bs, nq, self.num_heads, c // self.num_heads)
            query_sine_embed = self.qpos_sine_proj(ref_sine_embed)
            query_sine_embed = query_sine_embed.view(bs, nq, self.num_heads,
                                                     c // self.num_heads)
            q = torch.cat([q, query_sine_embed], dim=3).view(bs, nq, 2 * c)
            k = k.view(bs, hw, self.num_heads, c // self.num_heads)
            k_pos = k_pos.view(bs, hw, self.num_heads, c // self.num_heads)
            k = torch.cat([k, k_pos], dim=3).view(bs, hw, 2 * c)
            ca_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(ca_output)
        else:
            q_content = self.qcontent_proj(query)
            q_pos = self.qpos_proj(query_pos)
            k_content = self.kcontent_proj(query)
            k_pos = self.kpos_proj(query_pos)
            v = self.v_proj(query)
            q = q_content if q_pos is None else q_content + q_pos
            k = k_content if k_pos is None else k_content + k_pos
            sa_output = self.forward_attn(
                query=q,
                key=k,
                value=v,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask)[0]
            query = query + self.proj_drop(sa_output)

        return query
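

# Illustrative usage sketch for the self-attention path of
# `ConditionalAttention` (not from the original source; sizes assumed). With
# `cross_attn=False` the module acts as decoder self-attention with dropout
# and a residual connection:
#
#   >>> self_attn = ConditionalAttention(embed_dims=256, num_heads=8)
#   >>> query = torch.rand(2, 100, 256)
#   >>> query_pos = torch.rand(2, 100, 256)
#   >>> out = self_attn(query=query, key=query, query_pos=query_pos)
#   >>> out.shape                   # torch.Size([2, 100, 256])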


class MLP(BaseModule):
    """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
    used in DETR series detectors.

    Args:
        input_dim (int): Feature dim of the input tensor.
        hidden_dim (int): Feature dim of the hidden layers.
        output_dim (int): Feature dim of the output tensor.
        num_layers (int): Number of FFN layers. The last layer is a plain
            Linear without activation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 num_layers: int) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = ModuleList(
            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of MLP.

        Args:
            x (Tensor): The input feature, has shape
                (num_queries, bs, input_dim).

        Returns:
            Tensor: The output feature, has shape
            (num_queries, bs, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
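

# Illustrative usage sketch for `MLP` (not from the original source; the
# 3-layer box-regression head configuration below is assumed). The Linear
# layers act on the last dimension, so any leading shape works:
#
#   >>> reg_branch = MLP(input_dim=256, hidden_dim=256, output_dim=4,
#   >>>                  num_layers=3)
#   >>> reg_branch(torch.rand(100, 2, 256)).shape   # torch.Size([100, 2, 4])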


@MODELS.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1*1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default.
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project two-dimensional feature to
            one-dimensional feature. Defaults to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 256,
                 feat_channels: int = 64,
                 out_channels: Optional[int] = None,
                 input_feat_shape: int = 7,
                 with_proj: bool = True,
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super(DynamicConv, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor:
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature can be used
                to generate the parameter, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interacts with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
            (num_all_proposals, out_channels).
        """
        input_feature = input_feature.flatten(2).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # features has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features
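

# Illustrative usage sketch for `DynamicConv` (not from the original source;
# Sparse R-CNN-style sizes assumed): proposal features parameterize two 1x1
# "convolutions" that are applied to 7x7 RoI features via bmm.
#
#   >>> dynamic_conv = DynamicConv(in_channels=256, feat_channels=64,
#   >>>                            input_feat_shape=7)
#   >>> param_feat = torch.rand(100, 256)          # (num_proposals, in_channels)
#   >>> roi_feat = torch.rand(100, 256, 7, 7)      # (num_proposals, in_channels, 7, 7)
#   >>> dynamic_conv(param_feat, roi_feat).shape   # torch.Size([100, 256])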