# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .base import BaseDetector
@MODELS.register_module()
class DetectionTransformer(BaseDetector, metaclass=ABCMeta):
    r"""Base class for Detection Transformer.

    In Detection Transformer, an encoder is used to process output features of
    neck, then several queries interact with the encoder features using a
    decoder and do the regression and classification with the bounding box
    head.

    Args:
        backbone (:obj:`ConfigDict` or dict): Config of the backbone.
        neck (:obj:`ConfigDict` or dict, optional): Config of the neck.
            Defaults to None.
        encoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer encoder. Defaults to None.
        decoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer decoder. Defaults to None.
        bbox_head (:obj:`ConfigDict` or dict, optional): Config for the
            bounding box head module. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict, optional): Config
            of the positional encoding module. Defaults to None.
        num_queries (int, optional): Number of decoder query in Transformer.
            Defaults to 100.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the bounding box head module. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the bounding box head module. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. it usually includes,
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 encoder: OptConfigType = None,
                 decoder: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 positional_encoding: OptConfigType = None,
                 num_queries: int = 100,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Propagate the train/test configs into the head config before
        # building it, so the head owns its own assigner/sampler settings.
        # NOTE(review): `bbox_head` is documented as optional but is
        # dereferenced unconditionally here — passing None raises; confirm
        # every config supplies a head.
        bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # The transformer sub-configs are stored as-is; subclasses build the
        # actual modules from them in `_init_layers()`.
        self.encoder = encoder
        self.decoder = decoder
        self.positional_encoding = positional_encoding
        self.num_queries = num_queries

        # init model layers
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)
        self.bbox_head = MODELS.build(bbox_head)
        self._init_layers()

    @abstractmethod
    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        losses = self.bbox_head.loss(
            **head_inputs_dict, batch_data_samples=batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input images.
            Each DetDataSample usually contain 'pred_instances'. And the
            `pred_instances` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        results_list = self.bbox_head.predict(
            **head_inputs_dict,
            rescale=rescale,
            batch_data_samples=batch_data_samples)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        results = self.bbox_head.forward(**head_inputs_dict)
        return results

    def forward_transformer(self,
                            img_feats: Tuple[Tensor],
                            batch_data_samples: OptSampleList = None) -> Dict:
        """Forward process of Transformer, which includes four steps:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We
        summarized the parameters flow of the existing DETR-like detector,
        which can be illustrated as follow:

        .. code:: text

                 img_feats & batch_data_samples
                               |
                               V
                      +-----------------+
                      | pre_transformer |
                      +-----------------+
                          |          |
                          |          V
                          |    +-----------------+
                          |    | forward_encoder |
                          |    +-----------------+
                          |             |
                          |             V
                          |     +---------------+
                          |     |  pre_decoder  |
                          |     +---------------+
                          |         |       |
                          V         V       |
                      +-----------------+   |
                      | forward_decoder |   |
                      +-----------------+   |
                                |           |
                                V           V
                               head_inputs_dict

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

        Returns:
            tuple[Tensor]: Tuple of feature maps from neck. Each feature map
            has shape (bs, dim, H, W).
        """
        x = self.backbone(batch_inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    @abstractmethod
    def pre_transformer(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict, Dict]:
        """Process image features before feeding them to the transformer.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of encoder
            and the second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              'feat_pos', and other algorithm-specific arguments.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask', and
              other algorithm-specific arguments.
        """

    @abstractmethod
    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, **kwargs) -> Dict:
        """Forward with Transformer encoder.

        Args:
            feat (Tensor): Sequential features, has shape (bs, num_feat_points,
                dim).
            feat_mask (Tensor): ByteTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output and other algorithm-specific
            arguments.
        """

    @abstractmethod
    def pre_decoder(self, memory: Tensor, **kwargs) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            tuple[dict, dict]: The first dict contains the inputs of decoder
            and the second dict contains the inputs of the bbox_head function.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'query_pos',
              'memory', and other algorithm-specific arguments.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which is usually empty, or includes
              `enc_outputs_class` and `enc_outputs_coord` when the detector
              support 'two stage' or 'query selection' strategies.
        """

    @abstractmethod
    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        **kwargs) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output, `references` including
            the initial and intermediate reference_points, and other
            algorithm-specific arguments.
        """