# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from ..layers import (DetrTransformerDecoder, DetrTransformerEncoder,
                      SinePositionalEncoding)
from .base_detr import DetectionTransformer


@MODELS.register_module()
class DETR(DetectionTransformer):
    r"""Implementation of `DETR: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_.

    Code is modified from the `official github repo
    <https://github.com/facebookresearch/detr>`_.
    """

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DetrTransformerEncoder(**self.encoder)
        self.decoder = DetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # NOTE The embed_dims is typically passed from the inside out.
        # For example in DETR, the embed_dims is passed as
        # self_attn -> the first encoder layer -> encoder -> detector.
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'
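
    # A shape-level illustration of the check above, assuming the standard
    # DETR-R50 setting (embed_dims=256, num_feats=128; values not taken from
    # this file): SinePositionalEncoding emits num_feats channels for the
    # y-axis and num_feats for the x-axis and concatenates them, so the
    # positional embedding matches embed_dims only when 2 * 128 == 256.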

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def pre_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
        """Prepare the inputs of the Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `DetectionTransformer.forward_transformer`
        in `mmdet/models/detectors/base_detr.py`.

        Args:
            img_feats (Tuple[Tensor]): Tuple of features output from the neck,
                has shape (bs, c, h, w).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instances` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of encoder
            and the second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              and 'feat_pos'.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask'
              and 'memory_pos'.
        """
        feat = img_feats[-1]  # NOTE img_feats contains only one feature.
        batch_size, feat_dim, _, _ = feat.shape
        # Construct binary masks for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        input_img_h, input_img_w = batch_input_shape
        masks = feat.new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.

        masks = F.interpolate(
            masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1)
        # [batch_size, embed_dim, h, w]
        pos_embed = self.positional_encoding(masks)

        # use `view` instead of `flatten` for dynamically exporting to ONNX
        # [bs, c, h, w] -> [bs, h*w, c]
        feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        # [bs, h, w] -> [bs, h*w]
        masks = masks.view(batch_size, -1)

        # prepare transformer_inputs_dict
        encoder_inputs_dict = dict(
            feat=feat, feat_mask=masks, feat_pos=pos_embed)
        decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
        return encoder_inputs_dict, decoder_inputs_dict
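
    # A worked shape example for `pre_transformer`, assuming a hypothetical
    # batch padded to 800x1216 and a stride-32 backbone feature with
    # embed_dims=256 (the numbers are illustrative, not from this file):
    #   feat:      (bs, 256, 25, 38)  since 800/32 = 25 and 1216/32 = 38
    #   masks:     (bs, 800, 1216) -> interpolated to (bs, 25, 38)
    #   pos_embed: (bs, 256, 25, 38)
    # After flattening, feat and pos_embed become (bs, 950, 256) and masks
    # becomes (bs, 950), where 950 = 25 * 38.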

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor) -> Dict:
        """Forward with Transformer encoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `DetectionTransformer.forward_transformer`
        in `mmdet/models/detectors/base_detr.py`.

        Args:
            feat (Tensor): Sequential features, has shape
                (bs, num_feat_points, dim).
            feat_mask (Tensor): BoolTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output.
        """
        memory = self.encoder(
            query=feat, query_pos=feat_pos,
            key_padding_mask=feat_mask)  # for self_attn
        encoder_outputs_dict = dict(memory=memory)
        return encoder_outputs_dict
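
    # Note: each encoder layer performs self-attention over the flattened
    # feature tokens; `query_pos` is added to the queries and keys inside the
    # attention modules, and `key_padding_mask` excludes padded positions.
    # The output `memory` keeps the input shape (bs, num_feat_points, dim).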

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query` and `query_pos`.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `DetectionTransformer.forward_transformer`
        in `mmdet/models/detectors/base_detr.py`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos',
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
        """
        batch_size = memory.size(0)  # (bs, num_feat_points, dim)
        query_pos = self.query_embedding.weight
        # (num_queries, dim) -> (bs, num_queries, dim)
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
        query = torch.zeros_like(query_pos)

        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict
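
    # Note: the decoder queries start as zeros and acquire content purely
    # through attention, while the learned `query_embedding` weights act as
    # per-query positional embeddings (the "object queries" of the DETR
    # paper). With a hypothetical num_queries=100 and dim=256, both `query`
    # and `query_pos` are broadcast to (bs, 100, 256).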

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `DetectionTransformer.forward_transformer`
        in `mmdet/models/detectors/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): BoolTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim)
        """
        hidden_states = self.decoder(
            query=query,
            key=memory,
            value=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)  # for cross_attn
        head_inputs_dict = dict(hidden_states=hidden_states)
        return head_inputs_dict
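

# A minimal, self-contained sketch (plain PyTorch only, no mmdet required) of
# the mask construction and flattening performed in `pre_transformer` above.
# All sizes below are hypothetical and chosen only for illustration.
if __name__ == '__main__':
    bs, c, h, w = 2, 256, 25, 38
    feat = torch.rand(bs, c, h, w)

    # Full-resolution padding masks: non-zero marks padded pixels, zero marks
    # valid ones, following the convention noted in `pre_transformer`.
    input_h, input_w = 800, 1216
    masks = feat.new_ones((bs, input_h, input_w))
    masks[0, :800, :1000] = 0  # image 0 is narrower than the padded batch
    masks[1, :600, :1216] = 0  # image 1 is shorter than the padded batch

    # Downsample the masks to the feature resolution and flatten everything,
    # mirroring the reshaping done before the encoder.
    masks = F.interpolate(
        masks.unsqueeze(1), size=(h, w)).to(torch.bool).squeeze(1)
    flat_feat = feat.view(bs, c, -1).permute(0, 2, 1)  # (bs, h*w, c)
    flat_mask = masks.view(bs, -1)                     # (bs, h*w)
    assert flat_feat.shape == (bs, h * w, c)
    assert flat_mask.shape == (bs, h * w)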