dab_detr.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

from mmengine.model import uniform_init
from torch import Tensor, nn

from mmdet.registry import MODELS
from ..layers import SinePositionalEncoding
from ..layers.transformer import (DABDetrTransformerDecoder,
                                  DABDetrTransformerEncoder, inverse_sigmoid)
from .detr import DETR


@MODELS.register_module()
class DABDETR(DETR):
    r"""Implementation of `DAB-DETR:
    Dynamic Anchor Boxes are Better Queries for DETR.
    <https://arxiv.org/abs/2201.12329>`_.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DAB-DETR>`_.

    Args:
        with_random_refpoints (bool): Whether to randomly initialize query
            embeddings and not update them during training.
            Defaults to False.
        num_patterns (int): Inspired by Anchor-DETR. Defaults to 0.
    """

    def __init__(self,
                 *args,
                 with_random_refpoints: bool = False,
                 num_patterns: int = 0,
                 **kwargs) -> None:
        self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but got {num_patterns}.'
        self.num_patterns = num_patterns

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DABDetrTransformerEncoder(**self.encoder)
        self.decoder = DABDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # query_dim is the dimension of the dynamic anchor boxes used as
        # positional queries, e.g. 4 for (cx, cy, w, h).
        self.query_dim = self.decoder.query_dim
        self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'
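
    # Shape note (illustrative, using the common DAB-DETR settings of
    # embed_dims=256, num_feats=128 and num_queries=300; these values come
    # from typical configs and are not asserted by this file):
    #   * num_feats * 2 == embed_dims, i.e. 128 * 2 == 256.
    #   * query_dim is typically 4, so self.query_embedding.weight has shape
    #     (300, 4) and each row is one learnable dynamic anchor box
    #     (cx, cy, w, h).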

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DABDETR, self).init_weights()
        if self.with_random_refpoints:
            uniform_init(self.query_embedding)
            # Randomly initialize the (cx, cy) part of the anchor boxes and
            # keep it fixed during training, as described in the class
            # docstring.
            self.query_embedding.weight.data[:, :2] = \
                inverse_sigmoid(self.query_embedding.weight.data[:, :2])
            self.query_embedding.weight.data[:, :2].requires_grad = False
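
    # For reference, `inverse_sigmoid` is the logit function with clamping.
    # A simplified sketch of the utility imported above (the exact default
    # eps used by mmdetection may differ):
    #
    #   import torch
    #
    #   def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    #       x = x.clamp(min=0, max=1)
    #       x1 = x.clamp(min=eps)
    #       x2 = (1 - x).clamp(min=eps)
    #       return torch.log(x1 / x2)
    #
    # Applying it to the first two columns maps the uniformly sampled
    # (cx, cy) values from [0, 1] back to unnormalized logits before they
    # are frozen.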

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query` and `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of the
            decoder and the second dict contains the inputs of the bbox_head
            function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
        """
        batch_size = memory.size(0)
        query_pos = self.query_embedding.weight
        # (num_queries, query_dim) -> (bs, num_queries, query_dim)
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
        if self.num_patterns == 0:
            # Content queries start as zeros; the anchor boxes in `query_pos`
            # carry the positional information.
            query = query_pos.new_zeros(batch_size, self.num_queries,
                                        self.embed_dims)
        else:
            # Pair every anchor box with every pattern embedding, giving
            # num_queries * num_patterns content queries per image.
            query = self.patterns.weight[:, None, None, :]\
                .repeat(1, self.num_queries, batch_size, 1)\
                .view(-1, batch_size, self.embed_dims)\
                .permute(1, 0, 2)
            query_pos = query_pos.repeat(1, self.num_patterns, 1)
        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict
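
    # Shape walk-through for the pattern branch above, with illustrative
    # numbers (batch_size=2, num_queries=300, num_patterns=3, embed_dims=256,
    # query_dim=4):
    #   self.patterns.weight                   -> (3, 256)
    #   [:, None, None, :] + repeat            -> (3, 300, 2, 256)
    #   .view(-1, batch_size, embed_dims)      -> (900, 2, 256)
    #   .permute(1, 0, 2)                      -> (2, 900, 256)  # query
    #   query_pos.repeat(1, num_patterns, 1)   -> (2, 900, 4)    # query_pos
    # so every anchor box is paired with every pattern embedding, yielding
    # num_queries * num_patterns decoder queries per image.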

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, query_dim), where each positional
                query is an unsigmoided dynamic anchor box.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask,
            # iterative refinement for anchor boxes
            reg_branches=self.bbox_head.fc_reg)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict
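

# Illustrative data-flow sketch (not part of the original file): the hooks
# defined above are chained by the parent `DetectionTransformer` class in
# `forward_transformer`, roughly as follows:
#
#   encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
#       img_feats, batch_data_samples)
#   encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)
#   tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
#   decoder_inputs_dict.update(tmp_dec_in)
#   decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
#   head_inputs_dict.update(decoder_outputs_dict)
#
# The resulting `hidden_states` and `references` are then consumed by the
# DAB-DETR bbox head, which combines them with its classification and
# regression branches to produce the final predictions.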