conditional_detr_layers.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from torch import Tensor
from torch.nn import ModuleList

from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer
from .utils import MLP, ConditionalAttention, coordinate_to_encoding


class ConditionalDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Conditional DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers and other layers."""
        self.layers = ModuleList([
            ConditionalDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
        # Conditional DETR affine transformation of the positional query.
        self.query_scale = MLP(self.embed_dims, self.embed_dims,
                               self.embed_dims, 2)
        self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 2, 2)
        # 'qpos_proj' has been substituted with 'qpos_sine_proj' in every
        # decoder layer except the first one, so 'qpos_proj' is deleted
        # in the other layers.
        for layer_id in range(self.num_layers - 1):
            self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                key_padding_mask: Tensor = None):
        """Forward function of the decoder.

        Args:
            query (Tensor): The input query with shape
                (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim). If
                `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`. If not `None`, it will be added to
                `query` before the forward function. Defaults to `None`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not `None`, it will be added to
                `key` before the forward function. If `None`, and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                as `key_pos`. Defaults to `None`.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            Tuple[Tensor, Tensor]: The forwarded results with shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True`, otherwise with shape
            (1, bs, num_queries, dim), and the references with shape
            (bs, num_queries, 2).
        """
        reference_unsigmoid = self.ref_point_head(
            query_pos)  # [bs, num_queries, 2]
        reference = reference_unsigmoid.sigmoid()
        reference_xy = reference[..., :2]
        intermediate = []
        for layer_id, layer in enumerate(self.layers):
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(query)
            # get the sine embedding for the query reference points
            ref_sine_embed = coordinate_to_encoding(coord_tensor=reference_xy)
            # apply the content-dependent transformation
            ref_sine_embed = ref_sine_embed * pos_transformation
            query = layer(
                query,
                key=key,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                ref_sine_embed=ref_sine_embed,
                is_first=(layer_id == 0))
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))

        if self.return_intermediate:
            return torch.stack(intermediate), reference

        query = self.post_norm(query)
        return query.unsqueeze(0), reference
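

# Note on the conditional spatial query built in the loop above:
# `coordinate_to_encoding` (imported from .utils) is assumed here to map each
# normalized reference coordinate in [0, 1] to a DETR-style sinusoidal
# embedding, i.e. roughly interleaved sin/cos features of the coordinate at
# geometrically spaced frequencies, with the x and y embeddings concatenated
# so that the result matches `embed_dims`. Scaling that embedding element-wise
# by `query_scale(query)` yields the content-dependent spatial query of
# Conditional DETR; in the first layer the scale is the identity
# (`pos_transformation = 1`) because the content query carries no information
# yet.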


class ConditionalDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements the decoder layer in the Conditional DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN and normalization
        layers."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                ref_sine_embed: Tensor = None,
                is_first: bool = False):
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be
                added to `query` before the forward function. Defaults to
                `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before the forward function. If `None`, and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                as `key_pos`. Defaults to `None`.
            self_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to `None`.
            cross_attn_masks (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), same as in
                `nn.MultiheadAttention.forward`. Defaults to `None`.
            key_padding_mask (Tensor, optional): ByteTensor, has shape
                (bs, num_keys). Defaults to `None`.
            ref_sine_embed (Tensor, optional): The positional encoding of the
                query reference points used in cross-attention, has the same
                shape as `query`. Defaults to `None`.
            is_first (bool): An indicator of whether the current layer is the
                first layer of the decoder. Defaults to `False`.

        Returns:
            Tensor: Forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            ref_sine_embed=ref_sine_embed,
            is_first=is_first)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
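

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test showing
# how the decoder above might be constructed and called. The config keys
# mirror the Conditional DETR setup in MMDetection, but the constructor
# arguments of the parent `DetrTransformerDecoder` and the exact values
# (dims, heads, numbers of queries/keys) are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    layer_cfg = dict(
        self_attn_cfg=dict(embed_dims=256, num_heads=8),
        cross_attn_cfg=dict(embed_dims=256, num_heads=8, cross_attn=True),
        ffn_cfg=dict(embed_dims=256, feedforward_channels=2048))
    decoder = ConditionalDetrTransformerDecoder(
        num_layers=6, layer_cfg=layer_cfg, return_intermediate=True)

    bs, num_queries, num_keys, dim = 2, 300, 1064, 256
    query = torch.zeros(bs, num_queries, dim)       # content queries
    query_pos = torch.randn(bs, num_queries, dim)   # learned query embeddings
    key = torch.randn(bs, num_keys, dim)            # encoder memory
    key_pos = torch.randn(bs, num_keys, dim)        # encoder positional encoding

    outputs, references = decoder(
        query=query, key=key, query_pos=query_pos, key_pos=key_pos)
    # Expected: (num_layers, bs, num_queries, dim) and (bs, num_queries, 2).
    print(outputs.shape, references.shape)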