mask2former_layers.py

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from torch import Tensor

from .deformable_detr_layers import DeformableDetrTransformerEncoder
from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer


class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder in PixelDecoder of Mask2Former."""

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                reference_points: Tensor, **kwargs) -> Tensor:
        """Forward function of the Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The positional encoding for `query`, has
                shape (bs, num_queries, dim). It is added to `query` before
                the attention forward.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all
                levels, has shape (num_levels, 2), last dimension
                represents (h, w).
            level_start_index (Tensor): The start index of each level,
                a tensor of shape (num_levels, ) which can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the
                valid height relative to the width and the height of
                features in all levels, has shape (bs, num_levels, 2).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 2) with the last dimension arranged as
                (cx, cy).

        Returns:
            Tensor: Output queries of the Transformer encoder, also called
            'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim).
        """
        # `reference_points` are supplied by the caller and shared across
        # all layers, rather than being computed from `valid_ratios` inside
        # this method as in `DeformableDetrTransformerEncoder.forward`.
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query
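
# Shape sketch (illustrative, not part of the original file): with two
# feature levels of (h, w) = (32, 32) and (16, 16), the flattened
# multi-scale sequence length is num_queries = 32*32 + 16*16 = 1280, so
# for batch size 2 and embed dim 256 the encoder inputs would be:
#   query:             (2, 1280, 256)
#   query_pos:         (2, 1280, 256)
#   key_padding_mask:  (2, 1280)
#   spatial_shapes:    tensor([[32, 32], [16, 16]])
#   level_start_index: tensor([0, 1024])
#   valid_ratios:      (2, 2, 2)
#   reference_points:  (2, 1280, 2), per the docstring above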


class Mask2FormerTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Mask2Former."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            Mask2FormerTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
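
# Configuration sketch (illustrative): `num_layers`, `layer_cfg` and
# `post_norm_cfg` are consumed by the base class from a config dict. The
# inner keys below (self_attn_cfg / cross_attn_cfg / ffn_cfg) are
# assumptions based on `DetrTransformerDecoderLayer`:
#
#   decoder = Mask2FormerTransformerDecoder(
#       num_layers=9,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8),
#           cross_attn_cfg=dict(embed_dims=256, num_heads=8),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=2048)),
#       post_norm_cfg=dict(type='LN'))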


class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Mask2Former transformer."""

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries,
                dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to
                `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for
                `query`, has the same shape as `query`. If not `None`, it
                will be added to `query` before the attention forward.
                Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`,
                has the same shape as `key`. If not `None`, it will be
                added to `key` before the attention forward. If `None`, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to `None`.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_keys).
                Defaults to None.

        Returns:
            Tensor: Forwarded results, has shape (bs, num_queries, dim).
        """
        # Unlike the standard DETR decoder layer, Mask2Former applies
        # (masked) cross-attention first, then self-attention, then the
        # FFN, each followed by its own norm.
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
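

# Usage sketch (illustrative, not part of the original file): a single
# forward pass through one decoder layer with dummy tensors. The
# constructor keys mirror the config sketch above and are assumptions;
# in Mask2Former the key/value come from pixel-decoder features and
# `cross_attn_mask` from the previous layer's predicted masks.
#
#   import torch
#   layer = Mask2FormerTransformerDecoderLayer(
#       self_attn_cfg=dict(embed_dims=256, num_heads=8),
#       cross_attn_cfg=dict(embed_dims=256, num_heads=8),
#       ffn_cfg=dict(embed_dims=256, feedforward_channels=2048))
#   bs, num_queries, num_keys, dim = 2, 100, 1024, 256
#   out = layer(
#       query=torch.randn(bs, num_queries, dim),
#       key=torch.randn(bs, num_keys, dim),
#       value=torch.randn(bs, num_keys, dim),
#       query_pos=torch.randn(bs, num_queries, dim),
#       key_pos=torch.randn(bs, num_keys, dim))
#   assert out.shape == (bs, num_queries, dim)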