# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .single_stage import SingleStageDetector


@MODELS.register_module()
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_."""

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 panoptic_head: OptConfigType = None,
                 panoptic_fusion_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        # Bypass SingleStageDetector.__init__ and build the sub-modules here,
        # because the panoptic heads need train/test cfgs injected into their
        # configs before being built.
        super(SingleStageDetector, self).__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)

        # Copy the head configs before mutating them with train/test cfgs.
        panoptic_head_ = panoptic_head.deepcopy()
        panoptic_head_.update(train_cfg=train_cfg)
        panoptic_head_.update(test_cfg=test_cfg)
        self.panoptic_head = MODELS.build(panoptic_head_)

        panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
        panoptic_fusion_head_.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)

        # Mirror the class counts exposed by the panoptic head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
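
    # Construction sketch (illustrative, not part of this file): a MaskFormer
    # detector is normally built from a config dict through the registry.
    # The sub-config fields below are incomplete placeholders; only the
    # ``type`` names refer to real registered modules.
    #
    #     from mmdet.registry import MODELS
    #     cfg = dict(
    #         type='MaskFormer',
    #         backbone=dict(type='ResNet', depth=50),            # truncated
    #         panoptic_head=dict(type='MaskFormerHead'),         # truncated
    #         panoptic_fusion_head=dict(
    #             type='MaskFormerFusionHead'),                  # truncated
    #         train_cfg=dict(),                                  # placeholder
    #         test_cfg=dict(panoptic_on=True))                   # placeholder
    #     detector = MODELS.build(cfg)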

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """
        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs)
        losses = self.panoptic_head.loss(x, batch_data_samples)
        return losses
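
    # Training-time sketch (illustrative): the mmengine runner invokes
    # ``loss`` with the preprocessed batch; the returned dict of loss tensors
    # is then reduced by ``BaseModel.parse_losses`` before backpropagation.
    #
    #     losses = detector.loss(batch_inputs, batch_data_samples)
    #     total_loss, log_vars = detector.parse_losses(losses)
    #     total_loss.backward()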

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            ``pred_instances`` and ``pred_panoptic_seg``, and the
            ``pred_instances`` usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).

            The ``pred_panoptic_seg`` contains the following key:

            - sem_seg (Tensor): Panoptic segmentation mask, has a
              shape (1, h, w).
        """
        feats = self.extract_feat(batch_inputs)
        mask_cls_results, mask_pred_results = self.panoptic_head.predict(
            feats, batch_data_samples)
        results_list = self.panoptic_fusion_head.predict(
            mask_cls_results,
            mask_pred_results,
            batch_data_samples,
            rescale=rescale)
        results = self.add_pred_to_datasample(batch_data_samples, results_list)
        return results
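
    # Inference sketch (illustrative): ``predict`` attaches the post-processed
    # outputs to each DetDataSample, so the fields described above can be read
    # directly from the returned samples.
    #
    #     results = detector.predict(batch_inputs, batch_data_samples)
    #     pan_mask = results[0].pred_panoptic_seg.sem_seg    # (1, H, W)
    #     if 'pred_instances' in results[0]:                 # only present if
    #         instances = results[0].pred_instances          # the fusion head
    #                                                        # emits ins_results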

    def add_pred_to_datasample(self, data_samples: SampleList,
                               results_list: List[dict]) -> SampleList:
        """Add predictions to `DetDataSample`.

        Args:
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
            results_list (List[dict]): Instance segmentation, semantic
                segmentation and panoptic segmentation results.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            ``pred_instances`` and ``pred_panoptic_seg``, and the
            ``pred_instances`` usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).

            The ``pred_panoptic_seg`` contains the following key:

            - sem_seg (Tensor): Panoptic segmentation mask, has a
              shape (1, h, w).
        """
        for data_sample, pred_results in zip(data_samples, results_list):
            if 'pan_results' in pred_results:
                data_sample.pred_panoptic_seg = pred_results['pan_results']

            if 'ins_results' in pred_results:
                data_sample.pred_instances = pred_results['ins_results']

            assert 'sem_results' not in pred_results, 'semantic ' \
                'segmentation results are not supported yet.'

        return data_samples
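
    # Shape of ``results_list`` (illustrative): each element is a dict coming
    # from the fusion head, typically holding mmengine data containers, e.g.
    #
    #     results_list = [
    #         dict(
    #             pan_results=PixelData(sem_seg=pan_mask),   # (1, H, W) tensor
    #             ins_results=InstanceData(                  # present only when
    #                 scores=scores, labels=labels,          # instance results
    #                 bboxes=bboxes, masks=masks)),          # are enabled
    #         ...,
    #     ]
    #
    # where PixelData and InstanceData come from ``mmengine.structures``.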

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
            forward.
        """
        feats = self.extract_feat(batch_inputs)
        results = self.panoptic_head.forward(feats, batch_data_samples)
        return results
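
    # Raw-forward sketch (illustrative): ``_forward`` skips the fusion head and
    # post-processing, returning whatever ``panoptic_head.forward`` produces,
    # typically per-decoder-layer classification logits and mask predictions.
    #
    #     all_cls_scores, all_mask_preds = detector._forward(
    #         batch_inputs, batch_data_samples)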