# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

from mmengine.model import uniform_init
from torch import Tensor, nn

from mmdet.registry import MODELS
from ..layers import SinePositionalEncoding
from ..layers.transformer import (DABDetrTransformerDecoder,
                                  DABDetrTransformerEncoder, inverse_sigmoid)
from .detr import DETR


@MODELS.register_module()
class DABDETR(DETR):
    r"""Implementation of `DAB-DETR:
    Dynamic Anchor Boxes are Better Queries for DETR.
    <https://arxiv.org/abs/2201.12329>`_.

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DAB-DETR>`_.

    Args:
        with_random_refpoints (bool): Whether to randomly initialize query
            embeddings and not update them during training.
            Defaults to False.
        num_patterns (int): Number of pattern embeddings, inspired by
            Anchor DETR. Defaults to 0.
    """

    def __init__(self,
                 *args,
                 with_random_refpoints: bool = False,
                 num_patterns: int = 0,
                 **kwargs) -> None:
        self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but got {num_patterns}.'
        self.num_patterns = num_patterns

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DABDetrTransformerEncoder(**self.encoder)
        self.decoder = DABDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_dim = self.decoder.query_dim
        self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
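        # `query_dim` is expected to be 4 in DAB-DETR, so each query embedding
        # stores one (x, y, w, h) anchor box in unsigmoided (logit) space.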
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DABDETR, self).init_weights()
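        # When `with_random_refpoints` is True, the anchor embeddings are
        # randomly initialized; the first two dims (the anchor centers) are
        # mapped to logit space with `inverse_sigmoid` (the decoder applies a
        # sigmoid when decoding anchors) and frozen, so those reference points
        # stay fixed during training.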
        if self.with_random_refpoints:
            uniform_init(self.query_embedding)
            self.query_embedding.weight.data[:, :2] = \
                inverse_sigmoid(self.query_embedding.weight.data[:, :2])
            self.query_embedding.weight.data[:, :2].requires_grad = False

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query` and `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two-stage' or 'query selection' strategies.
        """
        batch_size = memory.size(0)
        query_pos = self.query_embedding.weight
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
        if self.num_patterns == 0:
            query = query_pos.new_zeros(batch_size, self.num_queries,
                                        self.embed_dims)
        else:
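            # Anchor DETR style patterns: each pattern embedding is shared by
            # every positional anchor, producing num_patterns * num_queries
            # content queries with shape (bs, num_patterns * num_queries,
            # embed_dims); `query_pos` is tiled below to match.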
            query = self.patterns.weight[:, None, None, :]\
                .repeat(1, self.num_queries, batch_size, 1)\
                .view(-1, batch_size, self.embed_dims)\
                .permute(1, 0, 2)
            query_pos = query_pos.repeat(1, self.num_patterns, 1)

        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask,
            # iterative refinement for anchor boxes
            reg_branches=self.bbox_head.fc_reg)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict
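

if __name__ == '__main__':
    # Standalone sketch (illustrative only, hypothetical sizes): it reproduces
    # the shape arithmetic of `pre_decoder` when `num_patterns > 0`, showing
    # how pattern embeddings multiply the number of content queries.
    # (Running it requires the mmdet package context because of the relative
    # imports above.)
    num_queries, num_patterns, embed_dims, batch_size = 4, 3, 8, 2
    patterns = nn.Embedding(num_patterns, embed_dims)
    query = patterns.weight[:, None, None, :] \
        .repeat(1, num_queries, batch_size, 1) \
        .view(-1, batch_size, embed_dims) \
        .permute(1, 0, 2)
    assert query.shape == (batch_size, num_patterns * num_queries, embed_dims)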