123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, Optional, Tuple
- import torch
- from torch import Tensor, nn
- from torch.nn.init import normal_
- from mmdet.registry import MODELS
- from mmdet.structures import OptSampleList
- from mmdet.utils import OptConfigType
- from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
- DinoTransformerDecoder, SinePositionalEncoding)
- from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention
@MODELS.register_module()
class DINO(DeformableDETR):
    r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
            query generator. Defaults to `None`.
    """

    def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None:
        """Build the detector and, if configured, the denoising generator.

        Args:
            *args: Positional arguments forwarded to the parent detector.
            dn_cfg (:obj:`ConfigDict` or dict, optional): Config of the
                denoising query generator. It must not contain
                `num_classes`, `embed_dims` or `num_matching_queries`,
                which are filled in here. The passed object is not
                modified. Defaults to `None`.
            **kwargs: Keyword arguments forwarded to the parent detector.
        """
        super().__init__(*args, **kwargs)
        assert self.as_two_stage, 'as_two_stage must be True for DINO'
        assert self.with_box_refine, 'with_box_refine must be True for DINO'

        if dn_cfg is not None:
            # Reject every key that is injected below. The two legacy names
            # `num_queries` and `hidden_dim` were rejected before and stay
            # rejected so existing behaviour is preserved.
            reserved_keys = ('num_classes', 'embed_dims',
                             'num_matching_queries', 'num_queries',
                             'hidden_dim')
            assert not any(key in dn_cfg for key in reserved_keys), \
                'The three keyword args `num_classes`, `embed_dims`, and ' \
                '`num_matching_queries` are set in `detector.__init__()`, ' \
                'users should not set them in `dn_cfg` config.'
            # Work on a shallow copy: writing into the caller's config would
            # make a second build from the same config fail the check above.
            dn_cfg = dict(dn_cfg)
            dn_cfg['num_classes'] = self.bbox_head.num_classes
            dn_cfg['embed_dims'] = self.embed_dims
            dn_cfg['num_matching_queries'] = self.num_queries
            self.dn_query_generator = CdnQueryGenerator(**dn_cfg)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        # NOTE In DINO, the query_embedding only contains content
        # queries, while in Deformable DETR, the query_embedding
        # contains both content and spatial queries, and in DETR,
        # it only contains spatial queries.

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        # Allocated uninitialized here; filled by `init_weights()`.
        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        # Deliberately starts the MRO lookup after `DeformableDETR`, so the
        # `init_weights` of `DeformableDETR` itself is not executed.
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        nn.init.xavier_uniform_(self.query_embedding.weight)
        normal_(self.level_embed)

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        """Forward process of Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.
        The difference is that the ground truth in `batch_data_samples` is
        required for the `pre_decoder` to prepare the query of DINO.
        Additionally, DINO inherits the `pre_transformer` method and the
        `forward_encoder` method of DeformableDETR. More details about the
        two methods can be found in `mmdet/detector/deformable_detr.py`.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              decoder input here are 4D boxes, although it has `points`
              in its name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `enc_outputs_class`,
              `enc_outputs_coord`, and `dn_meta` when `self.training` is
              `True`, else is empty.
        """
        # The unpacking also enforces that `memory` is 3-dimensional.
        bs, _, _ = memory.shape
        # The branches at index `num_layers` (one past the decoder layers)
        # are the ones applied to the encoder output.
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE The DINO selects top-k proposals according to scores of
        # multi-class classification, while DeformDETR, where the input
        # is `enc_outputs_class[..., 0]` selects according to scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        # `topk_coords` keeps the graph (it is returned for the encoder
        # loss); the copy used as decoder reference is detached.
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        # Broadcast the learned content queries over the batch:
        # (num_queries, dim) -> (bs, num_queries, dim).
        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            # Denoising queries come first, matching queries after them.
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # NOTE DINO calculates encoder losses on scores and coordinates
        # of selected top-k encoder queries, while DeformDETR is of all
        # encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self,
                        query: Tensor,
                        memory: Tensor,
                        memory_mask: Tensor,
                        reference_points: Tensor,
                        spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor,
                        dn_mask: Optional[Tensor] = None) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries_total, dim), where `num_queries_total` is the
                sum of `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries_total, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            dn_mask (Tensor, optional): The attention mask to prevent
                information leakage from different denoising groups and
                matching parts, will be used as `self_attn_mask` of the
                `self.decoder`, has shape (num_queries_total,
                num_queries_total).
                It is `None` when `self.training` is `False`.

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches)

        # `query` is (bs, num_queries_total, dim), so the number of queries
        # is dimension 1; `len(query)` would be the batch size. The check is
        # limited to training because only then a loss is produced, and so
        # that inference never needs `dn_query_generator`.
        if self.training and query.size(1) == self.num_queries:
            # NOTE: This is to make sure label_embedding can be involved to
            # produce loss even if there is no denoising query (no ground truth
            # target in this GPU), otherwise, this will raise runtime error in
            # distributed training.
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=list(references))
        return decoder_outputs_dict
|