# conditional_detr.py (3.0 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict
  3. import torch.nn as nn
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from ..layers import (ConditionalDetrTransformerDecoder,
  7. DetrTransformerEncoder, SinePositionalEncoding)
  8. from .detr import DETR
  9. @MODELS.register_module()
  10. class ConditionalDETR(DETR):
  11. r"""Implementation of `Conditional DETR for Fast Training Convergence.
  12. <https://arxiv.org/abs/2108.06152>`_.
  13. Code is modified from the `official github repo
  14. <https://github.com/Atten4Vis/ConditionalDETR>`_.
  15. """
  16. def _init_layers(self) -> None:
  17. """Initialize layers except for backbone, neck and bbox_head."""
  18. self.positional_encoding = SinePositionalEncoding(
  19. **self.positional_encoding)
  20. self.encoder = DetrTransformerEncoder(**self.encoder)
  21. self.decoder = ConditionalDetrTransformerDecoder(**self.decoder)
  22. self.embed_dims = self.encoder.embed_dims
  23. # NOTE The embed_dims is typically passed from the inside out.
  24. # For example in DETR, The embed_dims is passed as
  25. # self_attn -> the first encoder layer -> encoder -> detector.
  26. self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
  27. num_feats = self.positional_encoding.num_feats
  28. assert num_feats * 2 == self.embed_dims, \
  29. f'embed_dims should be exactly 2 times of num_feats. ' \
  30. f'Found {self.embed_dims} and {num_feats}.'
  31. def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
  32. memory_mask: Tensor, memory_pos: Tensor) -> Dict:
  33. """Forward with Transformer decoder.
  34. Args:
  35. query (Tensor): The queries of decoder inputs, has shape
  36. (bs, num_queries, dim).
  37. query_pos (Tensor): The positional queries of decoder inputs,
  38. has shape (bs, num_queries, dim).
  39. memory (Tensor): The output embeddings of the Transformer encoder,
  40. has shape (bs, num_feat_points, dim).
  41. memory_mask (Tensor): ByteTensor, the padding mask of the memory,
  42. has shape (bs, num_feat_points).
  43. memory_pos (Tensor): The positional embeddings of memory, has
  44. shape (bs, num_feat_points, dim).
  45. Returns:
  46. dict: The dictionary of decoder outputs, which includes the
  47. `hidden_states` and `references` of the decoder output.
  48. - hidden_states (Tensor): Has shape
  49. (num_decoder_layers, bs, num_queries, dim)
  50. - references (Tensor): Has shape
  51. (bs, num_queries, 2)
  52. """
  53. hidden_states, references = self.decoder(
  54. query=query,
  55. key=memory,
  56. query_pos=query_pos,
  57. key_pos=memory_pos,
  58. key_padding_mask=memory_mask)
  59. head_inputs_dict = dict(
  60. hidden_states=hidden_states, references=references)
  61. return head_inputs_dict