positional_encoding.py

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptMultiConfig


@MODELS.register_module()
class SinePositionalEncoding(BaseModule):
    """Position encoding with sine and cosine functions.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along the x-axis or y-axis. Note that the final returned
            dimension for each position is 2 times this value.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): An offset added to the embedding before
            normalization is applied. Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_feats: int,
                 temperature: int = 10000,
                 normalize: bool = False,
                 scale: float = 2 * math.pi,
                 eps: float = 1e-6,
                 offset: float = 0.,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if normalize:
            assert isinstance(scale, (float, int)), \
                'when normalize is set, scale should be provided and in ' \
                f'float or int type, found {type(scale)}'
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, mask: Tensor) -> Tensor:
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values represent
                ignored positions, while zero values mean valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
        """
        # For convenience of exporting to ONNX, it's required to convert
        # `mask` from bool to int.
        mask = mask.to(torch.int)
        not_mask = 1 - mask  # logical_not
        # Cumulative sums over valid positions give per-pixel y/x coordinates.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = (y_embed + self.offset) / \
                      (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / \
                      (x_embed[:, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Use `view` instead of `flatten` for dynamically exporting to ONNX.
        B, H, W = mask.size()
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self) -> str:
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_feats={self.num_feats}, '
        repr_str += f'temperature={self.temperature}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'eps={self.eps})'
        return repr_str
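
# A minimal usage sketch (kept as a comment; the sizes are illustrative):
# an all-zero mask marks every position as valid, and the module returns a
# [bs, num_feats * 2, h, w] embedding.
#
#     pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
#     mask = torch.zeros(1, 32, 32, dtype=torch.uint8)
#     pos = pos_enc(mask)  # torch.Size([1, 256, 32, 32])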


@MODELS.register_module()
class LearnedPositionalEncoding(BaseModule):
    """Position embedding with learnable embedding weights.

    Args:
        num_feats (int): The feature dimension for each position
            along the x-axis or y-axis. The final returned dimension for
            each position is 2 times this value.
        row_num_embed (int, optional): The dictionary size of row embeddings.
            Defaults to 50.
        col_num_embed (int, optional): The dictionary size of col embeddings.
            Defaults to 50.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Uniform', layer='Embedding')``.
    """

    def __init__(
        self,
        num_feats: int,
        row_num_embed: int = 50,
        col_num_embed: int = 50,
        init_cfg: MultiConfig = dict(type='Uniform', layer='Embedding')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.row_embed = nn.Embedding(row_num_embed, num_feats)
        self.col_embed = nn.Embedding(col_num_embed, num_feats)
        self.num_feats = num_feats
        self.row_num_embed = row_num_embed
        self.col_num_embed = col_num_embed

    def forward(self, mask: Tensor) -> Tensor:
        """Forward function for `LearnedPositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values represent
                ignored positions, while zero values mean valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
        """
        # Only the shape, device and batch size of `mask` are used here.
        h, w = mask.shape[-2:]
        x = torch.arange(w, device=mask.device)
        y = torch.arange(h, device=mask.device)
        x_embed = self.col_embed(x)  # [w, num_feats]
        y_embed = self.row_embed(y)  # [h, num_feats]
        # Broadcast the row/column embeddings to a [h, w, num_feats*2] grid,
        # then move channels first and repeat along the batch dimension.
        pos = torch.cat(
            (x_embed.unsqueeze(0).repeat(h, 1, 1),
             y_embed.unsqueeze(1).repeat(1, w, 1)),
            dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(
                mask.shape[0], 1, 1, 1)
        return pos

    def __repr__(self) -> str:
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_feats={self.num_feats}, '
        repr_str += f'row_num_embed={self.row_num_embed}, '
        repr_str += f'col_num_embed={self.col_num_embed})'
        return repr_str
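
# An illustrative smoke test, not part of the upstream module: with
# num_feats=128, both encoders map a [1, 32, 32] mask to a
# [1, 256, 32, 32] embedding.
if __name__ == '__main__':
    mask = torch.zeros(1, 32, 32, dtype=torch.uint8)  # all positions valid
    sine = SinePositionalEncoding(num_feats=128, normalize=True)
    learned = LearnedPositionalEncoding(num_feats=128)
    assert sine(mask).shape == (1, 256, 32, 32)
    assert learned(mask).shape == (1, 256, 32, 32)
    print(sine)
    print(learned)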