# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from ..layers import (DetrTransformerDecoder, DetrTransformerEncoder,
                      SinePositionalEncoding)
from .base_detr import DetectionTransformer


@MODELS.register_module()
class DETR(DetectionTransformer):
    r"""Implementation of `DETR: End-to-End Object Detection with
    Transformers <https://arxiv.org/pdf/2005.12872>`_.

    Code is modified from the `official github repo
    <https://github.com/facebookresearch/detr>`_.
    """

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DetrTransformerEncoder(**self.encoder)
        self.decoder = DetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # NOTE: embed_dims is typically passed from the inside out.
        # For example in DETR, embed_dims is passed as
        # self_attn -> the first encoder layer -> encoder -> detector.
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'
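        # For example, with the common DETR setting embed_dims=256 (an
        # illustrative value, not enforced here), the positional encoding
        # must be configured with num_feats=128.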

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def pre_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
        """Prepare the inputs of the Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'.
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.

        Args:
            img_feats (Tuple[Tensor]): Tuple of features output from the neck,
                each has shape (bs, c, h, w).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of encoder
            and the second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              and 'feat_pos'.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask'
              and 'memory_pos'.
        """
        feat = img_feats[-1]  # NOTE img_feats contains only one feature.
        batch_size, feat_dim, _, _ = feat.shape
        # construct binary masks used for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        input_img_h, input_img_w = batch_input_shape
        masks = feat.new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
        masks = F.interpolate(
            masks.unsqueeze(1), size=feat.shape[-2:]).to(torch.bool).squeeze(1)
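        # Illustrative example (assumed numbers): for a batch of 2 with
        # batch_input_shape (100, 100) and img_shapes (100, 60) and (80, 100),
        # the first mask is 1 on columns 60 onwards and the second on rows 80
        # onwards, so only the padded border is ignored once the mask is
        # downsampled to the feature resolution above.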
        # [batch_size, embed_dim, h, w]
        pos_embed = self.positional_encoding(masks)
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        # [bs, c, h, w] -> [bs, h*w, c]
        feat = feat.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        pos_embed = pos_embed.view(batch_size, feat_dim, -1).permute(0, 2, 1)
        # [bs, h, w] -> [bs, h*w]
        masks = masks.view(batch_size, -1)
        # prepare transformer_inputs_dict
        encoder_inputs_dict = dict(
            feat=feat, feat_mask=masks, feat_pos=pos_embed)
        decoder_inputs_dict = dict(memory_mask=masks, memory_pos=pos_embed)
        return encoder_inputs_dict, decoder_inputs_dict

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor) -> Dict:
        """Forward with Transformer encoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'.
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.

        Args:
            feat (Tensor): Sequential features, has shape
                (bs, num_feat_points, dim).
            feat_mask (Tensor): ByteTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output.
        """
        memory = self.encoder(
            query=feat, query_pos=feat_pos,
            key_padding_mask=feat_mask)  # for self_attn
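        # memory keeps the flattened shape (bs, num_feat_points, dim).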
        encoder_outputs_dict = dict(memory=memory)
        return encoder_outputs_dict

    def pre_decoder(self, memory: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query` and `query_pos`.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'.
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos',
              and 'memory'.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              supports 'two-stage' or 'query selection' strategies.
        """
        batch_size = memory.size(0)  # (bs, num_feat_points, dim)
        query_pos = self.query_embedding.weight
        # (num_queries, dim) -> (bs, num_queries, dim)
        query_pos = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
        query = torch.zeros_like(query_pos)
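        # NOTE: the content queries start as zeros; the learned `query_pos`
        # embeddings carry all query-specific information and are added back
        # to the queries inside every decoder layer.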
        decoder_inputs_dict = dict(
            query_pos=query_pos, query=query, memory=memory)
        head_inputs_dict = dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self, query: Tensor, query_pos: Tensor,
                        memory: Tensor, memory_mask: Tensor,
                        memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'.
        More details can be found at
        `DetectionTransformer.forward_transformer` in
        `mmdet/models/detectors/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim)
        """
        hidden_states = self.decoder(
            query=query,
            key=memory,
            value=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)  # for cross_attn
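        # hidden_states stacks the outputs of all decoder layers:
        # (num_decoder_layers, bs, num_queries, dim).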
        head_inputs_dict = dict(hidden_states=hidden_states)
        return head_inputs_dict
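

# A minimal standalone sketch (not part of the detector API): it re-traces
# the tensor bookkeeping of `pre_transformer` with plain torch and dummy
# values, so the mask/flatten shapes can be sanity-checked without building
# a full DETR config. Every size below (batch of 2, 256-dim features, a
# 25x38 feature map, the input/image shapes) is an assumed, illustrative
# value, not a value mandated by this module.
if __name__ == '__main__':
    bs, c, h, w = 2, 256, 25, 38
    dummy_feat = torch.rand(bs, c, h, w)
    batch_input_shape = (800, 1216)  # padded batch shape
    img_shapes = [(800, 1199), (640, 1216)]  # per-image unpadded shapes

    # 0 marks valid pixels, non-zero marks padded (ignored) pixels,
    # mirroring the convention used in `pre_transformer` above.
    dummy_masks = dummy_feat.new_ones((bs, *batch_input_shape))
    for i, (img_h, img_w) in enumerate(img_shapes):
        dummy_masks[i, :img_h, :img_w] = 0
    dummy_masks = F.interpolate(
        dummy_masks.unsqueeze(1),
        size=dummy_feat.shape[-2:]).to(torch.bool).squeeze(1)

    # flatten exactly as the detector does: [bs, c, h, w] -> [bs, h*w, c]
    flat_feat = dummy_feat.view(bs, c, -1).permute(0, 2, 1)
    flat_masks = dummy_masks.view(bs, -1)
    assert flat_feat.shape == (bs, h * w, c)
    assert flat_masks.shape == (bs, h * w)
    print('feat:', tuple(flat_feat.shape), 'mask:', tuple(flat_masks.shape))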