deformable_detr_layers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. from mmcv.cnn import build_norm_layer
  5. from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
  6. from mmcv.ops import MultiScaleDeformableAttention
  7. from mmengine.model import ModuleList
  8. from torch import Tensor, nn
  9. from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
  10. DetrTransformerEncoder, DetrTransformerEncoderLayer)
  11. from .utils import inverse_sigmoid
  12. class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
  13. """Transformer encoder of Deformable DETR."""
  14. def _init_layers(self) -> None:
  15. """Initialize encoder layers."""
  16. self.layers = ModuleList([
  17. DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
  18. for _ in range(self.num_layers)
  19. ])
  20. self.embed_dims = self.layers[0].embed_dims
  21. def forward(self, query: Tensor, query_pos: Tensor,
  22. key_padding_mask: Tensor, spatial_shapes: Tensor,
  23. level_start_index: Tensor, valid_ratios: Tensor,
  24. **kwargs) -> Tensor:
  25. """Forward function of Transformer encoder.
  26. Args:
  27. query (Tensor): The input query, has shape (bs, num_queries, dim).
  28. query_pos (Tensor): The positional encoding for query, has shape
  29. (bs, num_queries, dim).
  30. key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
  31. input. ByteTensor, has shape (bs, num_queries).
  32. spatial_shapes (Tensor): Spatial shapes of features in all levels,
  33. has shape (num_levels, 2), last dimension represents (h, w).
  34. level_start_index (Tensor): The start index of each level.
  35. A tensor has shape (num_levels, ) and can be represented
  36. as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
  37. valid_ratios (Tensor): The ratios of the valid width and the valid
  38. height relative to the width and the height of features in all
  39. levels, has shape (bs, num_levels, 2).
  40. Returns:
  41. Tensor: Output queries of Transformer encoder, which is also
  42. called 'encoder output embeddings' or 'memory', has shape
  43. (bs, num_queries, dim)
  44. """
  45. reference_points = self.get_encoder_reference_points(
  46. spatial_shapes, valid_ratios, device=query.device)
  47. for layer in self.layers:
  48. query = layer(
  49. query=query,
  50. query_pos=query_pos,
  51. key_padding_mask=key_padding_mask,
  52. spatial_shapes=spatial_shapes,
  53. level_start_index=level_start_index,
  54. valid_ratios=valid_ratios,
  55. reference_points=reference_points,
  56. **kwargs)
  57. return query
  58. @staticmethod
  59. def get_encoder_reference_points(
  60. spatial_shapes: Tensor, valid_ratios: Tensor,
  61. device: Union[torch.device, str]) -> Tensor:
  62. """Get the reference points used in encoder.
  63. Args:
  64. spatial_shapes (Tensor): Spatial shapes of features in all levels,
  65. has shape (num_levels, 2), last dimension represents (h, w).
  66. valid_ratios (Tensor): The ratios of the valid width and the valid
  67. height relative to the width and the height of features in all
  68. levels, has shape (bs, num_levels, 2).
  69. device (obj:`device` or str): The device acquired by the
  70. `reference_points`.
  71. Returns:
  72. Tensor: Reference points used in decoder, has shape (bs, length,
  73. num_levels, 2).
  74. """
  75. reference_points_list = []
  76. for lvl, (H, W) in enumerate(spatial_shapes):
  77. ref_y, ref_x = torch.meshgrid(
  78. torch.linspace(
  79. 0.5, H - 0.5, H, dtype=torch.float32, device=device),
  80. torch.linspace(
  81. 0.5, W - 0.5, W, dtype=torch.float32, device=device))
  82. ref_y = ref_y.reshape(-1)[None] / (
  83. valid_ratios[:, None, lvl, 1] * H)
  84. ref_x = ref_x.reshape(-1)[None] / (
  85. valid_ratios[:, None, lvl, 0] * W)
  86. ref = torch.stack((ref_x, ref_y), -1)
  87. reference_points_list.append(ref)
  88. reference_points = torch.cat(reference_points_list, 1)
  89. # [bs, sum(hw), num_level, 2]
  90. reference_points = reference_points[:, :, None] * valid_ratios[:, None]
  91. return reference_points
class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
    """Transformer Decoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers.

        Raises:
            ValueError: If ``post_norm_cfg`` is configured, since this
                decoder applies no post normalization.
        """
        self.layers = ModuleList([
            DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        if self.post_norm_cfg is not None:
            raise ValueError('There is not post_norm in '
                             f'{self._get_name()}')

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                value: Tensor,
                key_padding_mask: Tensor,
                reference_points: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                reg_branches: Optional[nn.Module] = None,
                **kwargs) -> Tuple[Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input queries, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The input positional query, has shape
                (bs, num_queries, dim). It will be added to `query` before
                forward function.
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
                the regression results. Only would be passed when
                `with_box_refine` is `True`, otherwise would be `None`.

        Returns:
            tuple[Tensor]: Outputs of Deformable Transformer Decoder.

            - output (Tensor): Output embeddings of the last decoder, has
              shape (num_queries, bs, embed_dims) when `return_intermediate`
              is `False`. Otherwise, Intermediate output embeddings of all
              decoder layers, has shape (num_decoder_layers, num_queries, bs,
              embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, Intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4). The
              coordinates are arranged as (cx, cy, w, h)
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for layer_id, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                # (cx, cy, w, h) references: scale both center and size by
                # the per-level valid ratios, hence the duplicated concat.
                reference_points_input = \
                    reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                # Iterative box refinement: correct the references with this
                # layer's regression head in unsigmoided (logit) space.
                tmp_reg_preds = reg_branches[layer_id](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp_reg_preds + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    # NOTE: `new_reference_points` aliases `tmp_reg_preds`,
                    # so the in-place write below also mutates the raw
                    # regression output.
                    new_reference_points = tmp_reg_preds
                    new_reference_points[..., :2] = tmp_reg_preds[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                # Detach so gradients do not flow through the refined
                # references into earlier layers.
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
  197. class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
  198. """Encoder layer of Deformable DETR."""
  199. def _init_layers(self) -> None:
  200. """Initialize self_attn, ffn, and norms."""
  201. self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
  202. self.embed_dims = self.self_attn.embed_dims
  203. self.ffn = FFN(**self.ffn_cfg)
  204. norms_list = [
  205. build_norm_layer(self.norm_cfg, self.embed_dims)[1]
  206. for _ in range(2)
  207. ]
  208. self.norms = ModuleList(norms_list)
  209. class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
  210. """Decoder layer of Deformable DETR."""
  211. def _init_layers(self) -> None:
  212. """Initialize self_attn, cross-attn, ffn, and norms."""
  213. self.self_attn = MultiheadAttention(**self.self_attn_cfg)
  214. self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
  215. self.embed_dims = self.self_attn.embed_dims
  216. self.ffn = FFN(**self.ffn_cfg)
  217. norms_list = [
  218. build_norm_layer(self.norm_cfg, self.embed_dims)[1]
  219. for _ in range(3)
  220. ]
  221. self.norms = ModuleList(norms_list)