# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.utils import OptConfigType
from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
                      DinoTransformerDecoder, SinePositionalEncoding)
from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention

@MODELS.register_module()
class DINO(DeformableDETR):
    r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
            query generator. Defaults to `None`.
    """

    def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert self.as_two_stage, 'as_two_stage must be True for DINO'
        assert self.with_box_refine, 'with_box_refine must be True for DINO'

        if dn_cfg is not None:
            assert 'num_classes' not in dn_cfg and \
                   'embed_dims' not in dn_cfg and \
                   'num_matching_queries' not in dn_cfg, \
                'The three keyword args `num_classes`, `embed_dims`, and ' \
                '`num_matching_queries` are set in `detector.__init__()`, ' \
                'users should not set them in `dn_cfg` config.'
            dn_cfg['num_classes'] = self.bbox_head.num_classes
            dn_cfg['embed_dims'] = self.embed_dims
            dn_cfg['num_matching_queries'] = self.num_queries
        self.dn_query_generator = CdnQueryGenerator(**dn_cfg)
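
    # A typical `dn_cfg`, adapted from the official mmdet DINO configs; the
    # values below are illustrative, and `num_classes`, `embed_dims`, and
    # `num_matching_queries` are filled in by `__init__` above:
    #
    #   dn_cfg = dict(
    #       label_noise_scale=0.5,
    #       box_noise_scale=1.0,
    #       group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100))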

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        # NOTE In DINO, the query_embedding only contains content
        # queries, while in Deformable DETR, the query_embedding
        # contains both content and spatial queries, and in DETR,
        # it only contains spatial queries.
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly twice num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
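
    # `level_embed` is added to the per-level positional encodings in
    # `pre_transformer` so the encoder can tell feature levels apart, while
    # `memory_trans_fc`/`memory_trans_norm` project the encoder memory into
    # the `output_memory` used by `gen_encoder_output_proposals`; both
    # methods are inherited from DeformableDETR.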

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        nn.init.xavier_uniform_(self.query_embedding.weight)
        normal_(self.level_embed)

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        """Forward process of Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.
        The difference is that the ground truth in `batch_data_samples` is
        required for `pre_decoder` to prepare the queries of DINO.
        Additionally, DINO inherits the `pre_transformer` and
        `forward_encoder` methods of DeformableDETR. More details about the
        two methods can be found in
        `mmdet/models/detectors/deformable_detr.py`.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from the neck.
                Each feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)
        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)
        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)
        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict
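
    # Rough data flow through `forward_transformer` (shapes as documented in
    # the docstrings of the methods involved):
    #   img_feats: tuple of (bs, dim, H, W) maps
    #     -> pre_transformer: flattens all levels into encoder inputs
    #     -> forward_encoder: memory of shape (bs, num_feat_points, dim)
    #     -> pre_decoder: selects top-k proposals, builds query and
    #        reference_points (plus denoising queries when training)
    #     -> forward_decoder: hidden_states and per-layer references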

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        """Prepare intermediate variables before entering the Transformer
        decoder, such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes `query`, `memory`,
              `reference_points`, and `dn_mask`. The reference points of the
              decoder input here are 4D boxes, despite the `points` in the
              name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `topk_score`,
              `topk_coords`, and `dn_meta` when `self.training` is `True`,
              and is empty otherwise.
        """
        bs, _, c = memory.shape
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE DINO selects top-k proposals according to the scores of
        # multi-class classification, while Deformable DETR, whose input
        # is `enc_outputs_class[..., 0]`, selects according to the scores
        # of binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # NOTE DINO calculates the encoder losses on the scores and
        # coordinates of the selected top-k encoder queries, while
        # Deformable DETR computes them on all encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict
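
    # Query layout produced by `pre_decoder` during training (the denoising
    # part comes from `CdnQueryGenerator`):
    #   query:            (bs, num_denoising_queries + num_queries, dim)
    #   reference_points: (bs, num_denoising_queries + num_queries, 4)
    # At inference time the denoising part is absent, leaving `num_queries`.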

    def forward_decoder(self,
                        query: Tensor,
                        memory: Tensor,
                        memory_mask: Tensor,
                        reference_points: Tensor,
                        spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor,
                        dn_mask: Optional[Tensor] = None) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries_total, dim), where `num_queries_total` is the
                sum of `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries_total, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor with shape (num_levels, ) that can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            dn_mask (Tensor, optional): The attention mask to prevent
                information leakage between the different denoising groups
                and the matching part, which will be used as the
                `self_attn_mask` of `self.decoder`, has shape
                (num_queries_total, num_queries_total).
                It is `None` when `self.training` is `False`.

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches)

        if len(query) == self.num_queries:
            # NOTE: This is to make sure label_embedding is involved in
            # producing the loss even if there is no denoising query (no
            # ground truth target in this GPU); otherwise the unused
            # parameter would raise a runtime error in distributed training.
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=list(references))
        return decoder_outputs_dict
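
# A minimal, self-contained sketch of the top-k proposal selection used in
# `pre_decoder`, with random tensors standing in for the encoder outputs
# (all sizes are illustrative, not part of this module's API):
#
#   bs, num_feat_points, num_classes, num_queries = 2, 100, 80, 10
#   enc_outputs_class = torch.rand(bs, num_feat_points, num_classes)
#   enc_outputs_coord_unact = torch.randn(bs, num_feat_points, 4)
#   topk_indices = torch.topk(
#       enc_outputs_class.max(-1)[0], k=num_queries, dim=1)[1]
#   topk_score = torch.gather(
#       enc_outputs_class, 1,
#       topk_indices.unsqueeze(-1).repeat(1, 1, num_classes))
#   topk_coords = torch.gather(
#       enc_outputs_coord_unact, 1,
#       topk_indices.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()
#   # topk_score: (bs, num_queries, num_classes)
#   # topk_coords: (bs, num_queries, 4), normalized (cx, cy, w, h)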