# detr_layers.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType


class DetrTransformerEncoder(BaseModule):
    """Encoder of DETR.

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries, has
                shape (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
            `True`, otherwise (num_queries, bs, dim).
        """
        for layer in self.layers:
            query = layer(query, query_pos, key_padding_mask, **kwargs)
        return query
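

# A minimal usage sketch (hypothetical, not part of the upstream mmdet
# module). It assumes the default 256-dim layer configuration and shows how
# the encoder consumes flattened image features of shape
# (bs, num_queries, dim) together with positional embeddings and a padding
# mask; the helper name and dummy shapes below are illustrative only.
def _example_encoder_usage() -> Tensor:
    """Hypothetical helper: build a 2-layer encoder and run a dummy pass."""
    encoder = DetrTransformerEncoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
    bs, num_queries, dim = 2, 100, 256
    query = torch.rand(bs, num_queries, dim)
    query_pos = torch.rand(bs, num_queries, dim)
    # False marks valid tokens, True marks padded positions.
    key_padding_mask = torch.zeros(bs, num_queries, dtype=torch.bool)
    # Output keeps the input shape (bs, num_queries, dim).
    return encoder(query, query_pos, key_padding_mask)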


class DetrTransformerDecoder(BaseModule):
    """Decoder of DETR.

    Args:
        num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
        post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            post normalization layer. Defaults to `LN`.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 post_norm_cfg: OptConfigType = dict(type='LN'),
                 return_intermediate: bool = True,
                 init_cfg: Union[dict, ConfigDict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.layer_cfg = layer_cfg
        self.num_layers = num_layers
        self.post_norm_cfg = post_norm_cfg
        self.return_intermediate = return_intermediate
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
                **kwargs) -> Tensor:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor): The input key, has shape (bs, num_keys, dim).
            value (Tensor): The input value with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).

        Returns:
            Tensor: The forwarded results will have shape
            (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is `True` else (1, bs, num_queries, dim).
        """
        intermediate = []
        for layer in self.layers:
            query = layer(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                **kwargs)
            if self.return_intermediate:
                intermediate.append(self.post_norm(query))
        query = self.post_norm(query)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return query.unsqueeze(0)
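

# A minimal usage sketch (hypothetical, not part of the upstream mmdet
# module). It assumes the default post-norm ('LN') and
# `return_intermediate=True`, and shows that the decoder stacks the
# per-layer outputs into a tensor of shape
# (num_decoder_layers, bs, num_queries, dim); the helper name and dummy
# shapes are illustrative only.
def _example_decoder_usage() -> Tensor:
    """Hypothetical helper: run a 2-layer decoder on dummy tensors."""
    decoder = DetrTransformerDecoder(
        num_layers=2,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
    bs, num_queries, num_keys, dim = 2, 100, 300, 256
    query = torch.zeros(bs, num_queries, dim)  # object queries start at zero
    query_pos = torch.rand(bs, num_queries, dim)  # learned query embeddings
    memory = torch.rand(bs, num_keys, dim)  # encoder output
    memory_pos = torch.rand(bs, num_keys, dim)
    key_padding_mask = torch.zeros(bs, num_keys, dtype=torch.bool)
    out = decoder(
        query=query,
        key=memory,
        value=memory,
        query_pos=query_pos,
        key_pos=memory_pos,
        key_padding_mask=key_padding_mask)
    # out has shape (2, bs, num_queries, dim) since return_intermediate=True.
    return out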


class DetrTransformerEncoderLayer(BaseModule):
    """Implements encoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256, num_heads=8, dropout=0.0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.self_attn_cfg = self_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, FFN, and normalization."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of an encoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query
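

# A minimal usage sketch (hypothetical, not part of the upstream mmdet
# module). It shows the post-norm ordering of a single encoder layer
# (self-attention -> norm -> FFN -> norm) on batch-first tensors, using the
# default 256-dim configuration; the helper name and shapes are illustrative.
def _example_encoder_layer_usage() -> Tensor:
    """Hypothetical helper: run one encoder layer with its default config."""
    layer = DetrTransformerEncoderLayer()  # embed_dims=256 by default
    bs, num_queries, dim = 2, 50, 256
    query = torch.rand(bs, num_queries, dim)
    query_pos = torch.rand(bs, num_queries, dim)
    key_padding_mask = torch.zeros(bs, num_queries, dtype=torch.bool)
    # Output keeps the input shape (bs, num_queries, dim).
    return layer(query, query_pos, key_padding_mask)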


class DetrTransformerDecoderLayer(BaseModule):
    """Implements decoder layer in DETR transformer.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 cross_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.self_attn_cfg = self_attn_cfg
        self.cross_attn_cfg = cross_attn_cfg
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        if 'batch_first' not in self.cross_attn_cfg:
            self.cross_attn_cfg['batch_first'] = True
        else:
            assert self.cross_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self-attention, cross-attention, FFN, and
        normalization."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """Forward function of a decoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If None, and `query_pos` has
                the same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
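

# A minimal usage sketch (hypothetical, not part of the upstream mmdet
# module). It shows how a single decoder layer mixes object queries via
# self-attention with encoder memory via cross-attention; `key_padding_mask`
# masks padded memory positions in the cross-attention only. The helper name
# and dummy shapes are illustrative assumptions.
def _example_decoder_layer_usage() -> Tensor:
    """Hypothetical helper: run one decoder layer with its default config."""
    layer = DetrTransformerDecoderLayer()  # embed_dims=256 by default
    bs, num_queries, num_keys, dim = 2, 100, 300, 256
    query = torch.zeros(bs, num_queries, dim)  # object queries start at zero
    query_pos = torch.rand(bs, num_queries, dim)  # learned query embeddings
    memory = torch.rand(bs, num_keys, dim)  # encoder output
    memory_pos = torch.rand(bs, num_keys, dim)
    key_padding_mask = torch.zeros(bs, num_keys, dtype=torch.bool)
    # Output keeps the query shape (bs, num_queries, dim).
    return layer(
        query,
        key=memory,
        value=memory,
        query_pos=query_pos,
        key_pos=memory_pos,
        key_padding_mask=key_padding_mask)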