- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Tuple, Union
- import torch
- from mmcv.cnn import build_norm_layer
- from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
- from mmcv.ops import MultiScaleDeformableAttention
- from mmengine.model import ModuleList
- from torch import Tensor, nn
- from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
- DetrTransformerEncoder, DetrTransformerEncoderLayer)
- from .utils import inverse_sigmoid
class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
    """Transformer encoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Build the stack of deformable encoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        # All layers share one width; expose it from the first layer.
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim)
        """
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        # Every layer receives the same side inputs; only `query` evolves.
        layer_inputs = dict(
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reference_points=reference_points,
            **kwargs)
        for encoder_layer in self.layers:
            query = encoder_layer(query=query, **layer_inputs)
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in encoder.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device acquired by the
                `reference_points`.

        Returns:
            Tensor: Reference points used in decoder, has shape (bs, length,
            num_levels, 2).
        """
        per_level_refs = []
        for level, (h, w) in enumerate(spatial_shapes):
            # Pixel-center coordinates of every location in this level.
            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0.5, h - 0.5, h, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, w - 0.5, w, dtype=torch.float32, device=device))
            # Normalize by the valid extent of each sample in the batch.
            grid_y = grid_y.reshape(-1)[None] / (
                valid_ratios[:, None, level, 1] * h)
            grid_x = grid_x.reshape(-1)[None] / (
                valid_ratios[:, None, level, 0] * w)
            per_level_refs.append(torch.stack((grid_x, grid_y), -1))
        reference_points = torch.cat(per_level_refs, 1)
        # (bs, sum(hw), 1, 2) * (bs, 1, num_levels, 2)
        #   -> (bs, sum(hw), num_levels, 2)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
    """Transformer Decoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        # All layers share the same width; expose it from the first one.
        self.embed_dims = self.layers[0].embed_dims
        # Deformable DETR's decoder does not apply a final post-norm, so a
        # configured `post_norm_cfg` is rejected as a configuration error.
        if self.post_norm_cfg is not None:
            raise ValueError('There is not post_norm in '
                             f'{self._get_name()}')

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                value: Tensor,
                key_padding_mask: Tensor,
                reference_points: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                reg_branches: Optional[nn.Module] = None,
                **kwargs) -> Tuple[Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input queries, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The input positional query, has shape
                (bs, num_queries, dim). It will be added to `query` before
                forward function.
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
                the regression results. Only would be passed when
                `with_box_refine` is `True`, otherwise would be `None`.

        Returns:
            tuple[Tensor]: Outputs of Deformable Transformer Decoder.

            - output (Tensor): Output embeddings of the last decoder, has
              shape (num_queries, bs, embed_dims) when `return_intermediate`
              is `False`. Otherwise, Intermediate output embeddings of all
              decoder layers, has shape (num_decoder_layers, num_queries, bs,
              embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, Intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4). The
              coordinates are arranged as (cx, cy, w, h)
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for layer_id, layer in enumerate(self.layers):
            # Scale the normalized reference points by the per-level valid
            # ratios so that sampling stays inside each level's valid
            # (unpadded) region:
            # (bs, num_queries, 1, 2|4) * (bs, 1, num_levels, 2|4).
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                # Iterative bounding-box refinement: each layer's regression
                # branch predicts a delta in inverse-sigmoid (logit) space.
                tmp_reg_preds = reg_branches[layer_id](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp_reg_preds + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    # NOTE: `new_reference_points` aliases `tmp_reg_preds`;
                    # the in-place write below also mutates `tmp_reg_preds`
                    # (which is not used again afterwards). Only the (cx, cy)
                    # channels are refined; the (w, h) channels keep the raw
                    # predictions before the sigmoid.
                    new_reference_points = tmp_reg_preds
                    new_reference_points[..., :2] = tmp_reg_preds[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                # Detach so gradients do not flow through the refinement
                # chain across decoder layers.
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
    """Encoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Build self-attention, FFN and the two layer norms."""
        # Deformable DETR replaces vanilla self-attention with
        # multi-scale deformable attention.
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # One norm after self-attention, one after the FFN.
        self.norms = ModuleList([
            build_norm_layer(self.norm_cfg, self.embed_dims)[1],
            build_norm_layer(self.norm_cfg, self.embed_dims)[1],
        ])
class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Build self-attention, cross-attention, FFN and the three norms."""
        # Vanilla multi-head attention among queries ...
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        # ... and multi-scale deformable attention over the image features.
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # One norm after each of: self-attention, cross-attention, FFN.
        self.norms = ModuleList([
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ])