# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

from mmengine.model import uniform_init
from torch import Tensor, nn

from mmdet.registry import MODELS
from ..layers import SinePositionalEncoding
from ..layers.transformer import (DABDetrTransformerDecoder,
                                  DABDetrTransformerEncoder, inverse_sigmoid)
from .detr import DETR


@MODELS.register_module()
class DABDETR(DETR):
    r"""Implementation of `DAB-DETR:
    Dynamic Anchor Boxes are Better Queries for DETR
    <https://arxiv.org/abs/2201.12329>`_.

    Code is modified from the `official github repo
    <https://github.com/IDEA-opensource/DAB-DETR>`_.

    Args:
        with_random_refpoints (bool): Whether to randomly initialize query
            embeddings and not update them during training.
            Defaults to False.
        num_patterns (int): Number of pattern embeddings, inspired by
            Anchor-DETR. Defaults to 0.
    """

    def __init__(self,
                 *args,
                 with_random_refpoints: bool = False,
                 num_patterns: int = 0,
                 **kwargs) -> None:
        self.with_random_refpoints = with_random_refpoints
        assert isinstance(num_patterns, int), \
            f'num_patterns should be int but got {num_patterns}.'
        self.num_patterns = num_patterns

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DABDetrTransformerEncoder(**self.encoder)
        self.decoder = DABDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_dim = self.decoder.query_dim
        self.query_embedding = nn.Embedding(self.num_queries, self.query_dim)
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DABDETR, self).init_weights()
        if self.with_random_refpoints:
            uniform_init(self.query_embedding)
            self.query_embedding.weight.data[:, :2] = \
                inverse_sigmoid(self.query_embedding.weight.data[:, :2])
            self.query_embedding.weight.data[:, :2].requires_grad = False

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query` and `query_pos`.

        Args:
            memory (Tensor): The output embeddings of the Transformer
                encoder, has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of the
            decoder and the second dict contains the inputs of the bbox_head
            functions.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos'
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two stage' or 'query selection' strategies.
""" batch_size = memory.size(0) query_pos = self.query_embedding.weight query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1) if self.num_patterns == 0: query = query_pos.new_zeros(batch_size, self.num_queries, self.embed_dims) else: query = self.patterns.weight[:, None, None, :]\ .repeat(1, self.num_queries, batch_size, 1)\ .view(-1, batch_size, self.embed_dims)\ .permute(1, 0, 2) query_pos = query_pos.repeat(1, self.num_patterns, 1) decoder_inputs_dict = dict( query_pos=query_pos, query=query, memory=memory) head_inputs_dict = dict() return decoder_inputs_dict, head_inputs_dict def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, memory_mask: Tensor, memory_pos: Tensor) -> Dict: """Forward with Transformer decoder. Args: query (Tensor): The queries of decoder inputs, has shape (bs, num_queries, dim). query_pos (Tensor): The positional queries of decoder inputs, has shape (bs, num_queries, dim). memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). memory_pos (Tensor): The positional embeddings of memory, has shape (bs, num_feat_points, dim). Returns: dict: The dictionary of decoder outputs, which includes the `hidden_states` and `references` of the decoder output. """ hidden_states, references = self.decoder( query=query, key=memory, query_pos=query_pos, key_pos=memory_pos, key_padding_mask=memory_mask, reg_branches=self.bbox_head. fc_reg # iterative refinement for anchor boxes ) head_inputs_dict = dict( hidden_states=hidden_states, references=references) return head_inputs_dict