dab_detr_layers.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import ModuleList
from torch import Tensor

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import (MLP, ConditionalAttention, coordinate_to_encoding,
                    inverse_sigmoid)


class DABDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in DAB-DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN, normalization and
        others."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)
        self.keep_query_pos = self.cross_attn.keep_query_pos

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                ref_sine_embed: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False,
                **kwargs) -> Tensor:
  39. """
  40. Args:
  41. query (Tensor): The input query with shape [bs, num_queries,
  42. dim].
  43. key (Tensor): The key tensor with shape [bs, num_keys,
  44. dim].
  45. query_pos (Tensor): The positional encoding for query in self
  46. attention, with the same shape as `x`.
  47. key_pos (Tensor): The positional encoding for `key`, with the
  48. same shape as `key`.
  49. ref_sine_embed (Tensor): The positional encoding for query in
  50. cross attention, with the same shape as `x`.
  51. Defaults to None.
  52. self_attn_masks (Tensor): ByteTensor mask with shape [num_queries,
  53. num_keys]. Same in `nn.MultiheadAttention.forward`.
  54. Defaults to None.
  55. cross_attn_masks (Tensor): ByteTensor mask with shape [num_queries,
  56. num_keys]. Same in `nn.MultiheadAttention.forward`.
  57. Defaults to None.
  58. key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
  59. Defaults to None.
  60. is_first (bool): A indicator to tell whether the current layer
  61. is the first layer of the decoder.
  62. Defaults to False.
  63. Returns:
  64. Tensor: forwarded results with shape
  65. [bs, num_queries, dim].
  66. """
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks,
            **kwargs)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            ref_sine_embed=ref_sine_embed,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            is_first=is_first,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)

        return query
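

# A minimal usage sketch for ``DABDetrTransformerDecoderLayer``. The function
# name is hypothetical and the config keys (embed_dims, num_heads, cross_attn,
# feedforward_channels, norm type) are assumptions modelled on the DAB-DETR
# config in MMDetection; check the actual config file for the exact schema.
# The function is only defined, never called, so importing the module stays
# side-effect free.
def _dab_decoder_layer_usage_sketch():
    layer = DABDetrTransformerDecoderLayer(
        self_attn_cfg=dict(embed_dims=256, num_heads=8),
        cross_attn_cfg=dict(embed_dims=256, num_heads=8, cross_attn=True),
        ffn_cfg=dict(embed_dims=256, feedforward_channels=2048),
        norm_cfg=dict(type='LN'))
    bs, num_queries, num_keys, dim = 2, 300, 1064, 256
    query = torch.randn(bs, num_queries, dim)  # content queries
    memory = torch.randn(bs, num_keys, dim)  # flattened encoder features
    query_pos = torch.randn(bs, num_queries, dim)  # positional part of query
    key_pos = torch.randn(bs, num_keys, dim)  # sine embedding of memory
    ref_sine_embed = torch.randn(bs, num_queries, dim)  # anchor sine embedding
    out = layer(
        query,
        memory,
        query_pos=query_pos,
        key_pos=key_pos,
        ref_sine_embed=ref_sine_embed,
        is_first=True)
    assert out.shape == (bs, num_queries, dim)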


class DABDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of DAB-DETR.

    Args:
        query_dim (int): The last dimension of query pos,
            4 for anchor format, 2 for point format.
            Defaults to 4.
        query_scale_type (str): Type of transformation applied
            to content query. Defaults to `cond_elewise`.
        with_modulated_hw_attn (bool): Whether to inject h&w info
            during cross conditional attention. Defaults to True.
    """

    def __init__(self,
                 *args,
                 query_dim: int = 4,
                 query_scale_type: str = 'cond_elewise',
                 with_modulated_hw_attn: bool = True,
                 **kwargs):

        self.query_dim = query_dim
        self.query_scale_type = query_scale_type
        self.with_modulated_hw_attn = with_modulated_hw_attn

        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initialize decoder layers and other layers."""
        assert self.query_dim in [2, 4], \
            'dab-detr only supports anchor prior or reference point prior'
        assert self.query_scale_type in [
            'cond_elewise', 'cond_scalar', 'fix_elewise'
        ]

        self.layers = ModuleList([
            DABDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])

        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims

        self.post_norm = build_norm_layer(self.post_norm_cfg, embed_dims)[1]
        if self.query_scale_type == 'cond_elewise':
            self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)
        elif self.query_scale_type == 'cond_scalar':
            self.query_scale = MLP(embed_dims, embed_dims, 1, 2)
        elif self.query_scale_type == 'fix_elewise':
            self.query_scale = nn.Embedding(self.num_layers, embed_dims)
        else:
            raise NotImplementedError('Unknown query_scale_type: {}'.format(
                self.query_scale_type))

        self.ref_point_head = MLP(self.query_dim // 2 * embed_dims, embed_dims,
                                  embed_dims, 2)

        if self.with_modulated_hw_attn and self.query_dim == 4:
            self.ref_anchor_head = MLP(embed_dims, embed_dims, 2, 2)

        self.keep_query_pos = self.layers[0].keep_query_pos
        if not self.keep_query_pos:
            for layer_id in range(self.num_layers - 1):
                self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                reg_branches: nn.Module,
                key_padding_mask: Tensor = None,
                **kwargs) -> List[Tensor]:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim).
            query_pos (Tensor): The unsigmoided reference anchors used as
                query positions, with shape (bs, num_queries, query_dim),
                where query_dim is 2 or 4.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            reg_branches (nn.Module): The regression branch for dynamically
                updating references in each layer.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            List[Tensor]: A list of two tensors. The first is the forwarded
            query with shape (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is True, otherwise with shape
            (1, bs, num_queries, dim). The second is the reference points
            with shape (num_decoder_layers, bs, num_queries, 2/4).
        """
        output = query
        unsigmoid_references = query_pos

        reference_points = unsigmoid_references.sigmoid()
        intermediate_reference_points = [reference_points]

        intermediate = []
        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[..., :self.query_dim]
            ref_sine_embed = coordinate_to_encoding(
                coord_tensor=obj_center, num_feats=self.embed_dims // 2)
            query_pos = self.ref_point_head(
                ref_sine_embed)  # [bs, nq, 2c] -> [bs, nq, c]
            # For the first decoder layer, do not apply transformation
            if self.query_scale_type != 'fix_elewise':
                if layer_id == 0:
                    pos_transformation = 1
                else:
                    pos_transformation = self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]
            # apply transformation
            ref_sine_embed = ref_sine_embed[
                ..., :self.embed_dims] * pos_transformation
            # modulated height and width attention
            if self.with_modulated_hw_attn:
                assert obj_center.size(-1) == 4
                ref_hw = self.ref_anchor_head(output).sigmoid()
                ref_sine_embed[..., self.embed_dims // 2:] *= \
                    (ref_hw[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                ref_sine_embed[..., :self.embed_dims // 2] *= \
                    (ref_hw[..., 1] / obj_center[..., 3]).unsqueeze(-1)

            output = layer(
                output,
                key,
                query_pos=query_pos,
                ref_sine_embed=ref_sine_embed,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                is_first=(layer_id == 0),
                **kwargs)
            # iterative update of the reference points
            tmp_reg_preds = reg_branches(output)
            tmp_reg_preds[..., :self.query_dim] += inverse_sigmoid(
                reference_points)
            new_reference_points = tmp_reg_preds[
                ..., :self.query_dim].sigmoid()
            if layer_id != self.num_layers - 1:
                intermediate_reference_points.append(new_reference_points)
            reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.post_norm(output))

        output = self.post_norm(output)

        if self.return_intermediate:
            return [
                torch.stack(intermediate),
                torch.stack(intermediate_reference_points),
            ]
        else:
            return [
                output.unsqueeze(0),
                torch.stack(intermediate_reference_points)
            ]
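

# A minimal usage sketch for ``DABDetrTransformerDecoder``: ``query_pos``
# carries unsigmoided (cx, cy, w, h) anchors that are refined layer by layer
# through ``reg_branches``. The function name, the config keys and the plain
# ``nn.Linear`` regression head are illustrative assumptions; the real model
# builds these from the DAB-DETR config and bbox head in MMDetection. The
# function is only defined, never called.
def _dab_decoder_usage_sketch():
    decoder = DABDetrTransformerDecoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8),
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, cross_attn=True),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=2048),
            norm_cfg=dict(type='LN')),
        return_intermediate=True)
    bs, num_queries, num_keys, dim = 2, 300, 1064, 256
    query = torch.zeros(bs, num_queries, dim)  # content queries
    memory = torch.randn(bs, num_keys, dim)  # encoder output
    anchor_boxes = torch.rand(bs, num_queries, 4)  # (cx, cy, w, h) in (0, 1)
    query_pos = inverse_sigmoid(anchor_boxes)  # unsigmoided anchors
    key_pos = torch.randn(bs, num_keys, dim)  # sine embedding of memory
    reg_branches = nn.Linear(dim, 4)  # stand-in regression head
    hidden_states, references = decoder(
        query=query,
        key=memory,
        query_pos=query_pos,
        key_pos=key_pos,
        reg_branches=reg_branches,
        key_padding_mask=None)
    # stacked per-layer outputs and the per-layer reference boxes
    assert hidden_states.shape == (2, bs, num_queries, dim)
    assert references.shape == (2, bs, num_queries, 4)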


class DABDetrTransformerEncoder(DetrTransformerEncoder):
    """Encoder of DAB-DETR."""

    def _init_layers(self):
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims
        self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs):
  241. """Forward function of encoder.
  242. Args:
  243. query (Tensor): Input queries of encoder, has shape
  244. (bs, num_queries, dim).
  245. query_pos (Tensor): The positional embeddings of the queries, has
  246. shape (bs, num_feat_points, dim).
  247. key_padding_mask (Tensor): ByteTensor, the key padding mask
  248. of the queries, has shape (bs, num_feat_points).
  249. Returns:
  250. Tensor: With shape (num_queries, bs, dim).
  251. """
        for layer in self.layers:
            pos_scales = self.query_scale(query)
            query = layer(
                query,
                query_pos=query_pos * pos_scales,
                key_padding_mask=key_padding_mask,
                **kwargs)

        return query
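

# A minimal usage sketch for ``DABDetrTransformerEncoder``: before every layer
# the positional embedding is re-scaled by the content-conditioned
# ``query_scale`` MLP. The function name and config keys are assumptions
# modelled on the DAB-DETR config in MMDetection; check the actual config file
# for the exact schema. The function is only defined, never called.
def _dab_encoder_usage_sketch():
    encoder = DABDetrTransformerEncoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)))
    bs, num_feat_points, dim = 2, 1064, 256
    feat = torch.randn(bs, num_feat_points, dim)  # flattened image features
    pos_embed = torch.randn(bs, num_feat_points, dim)  # sine position embedding
    padding_mask = torch.zeros(bs, num_feat_points, dtype=torch.bool)
    memory = encoder(
        query=feat, query_pos=pos_embed, key_padding_mask=padding_mask)
    assert memory.shape == (bs, num_feat_points, dim)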