maskformer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Tuple

from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .single_stage import SingleStageDetector
  8. @MODELS.register_module()
  9. class MaskFormer(SingleStageDetector):
  10. r"""Implementation of `Per-Pixel Classification is
  11. NOT All You Need for Semantic Segmentation
  12. <https://arxiv.org/pdf/2107.06278>`_."""
  13. def __init__(self,
  14. backbone: ConfigType,
  15. neck: OptConfigType = None,
  16. panoptic_head: OptConfigType = None,
  17. panoptic_fusion_head: OptConfigType = None,
  18. train_cfg: OptConfigType = None,
  19. test_cfg: OptConfigType = None,
  20. data_preprocessor: OptConfigType = None,
  21. init_cfg: OptMultiConfig = None):
  22. super(SingleStageDetector, self).__init__(
  23. data_preprocessor=data_preprocessor, init_cfg=init_cfg)
  24. self.backbone = MODELS.build(backbone)
  25. if neck is not None:
  26. self.neck = MODELS.build(neck)
  27. panoptic_head_ = panoptic_head.deepcopy()
  28. panoptic_head_.update(train_cfg=train_cfg)
  29. panoptic_head_.update(test_cfg=test_cfg)
  30. self.panoptic_head = MODELS.build(panoptic_head_)
  31. panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
  32. panoptic_fusion_head_.update(test_cfg=test_cfg)
  33. self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_)
  34. self.num_things_classes = self.panoptic_head.num_things_classes
  35. self.num_stuff_classes = self.panoptic_head.num_stuff_classes
  36. self.num_classes = self.panoptic_head.num_classes
  37. self.train_cfg = train_cfg
  38. self.test_cfg = test_cfg
  39. def loss(self, batch_inputs: Tensor,
  40. batch_data_samples: SampleList) -> Dict[str, Tensor]:
  41. """
  42. Args:
  43. batch_inputs (Tensor): Input images of shape (N, C, H, W).
  44. These should usually be mean centered and std scaled.
  45. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  46. data samples. It usually includes information such
  47. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  48. Returns:
  49. dict[str, Tensor]: a dictionary of loss components
  50. """
  51. x = self.extract_feat(batch_inputs)
  52. losses = self.panoptic_head.loss(x, batch_data_samples)
  53. return losses
  54. def predict(self,
  55. batch_inputs: Tensor,
  56. batch_data_samples: SampleList,
  57. rescale: bool = True) -> SampleList:
  58. """Predict results from a batch of inputs and data samples with post-
  59. processing.
  60. Args:
  61. batch_inputs (Tensor): Inputs with shape (N, C, H, W).
  62. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  63. Samples. It usually includes information such as
  64. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  65. rescale (bool): Whether to rescale the results.
  66. Defaults to True.
  67. Returns:
  68. list[:obj:`DetDataSample`]: Detection results of the
  69. input images. Each DetDataSample usually contain
  70. 'pred_instances' and `pred_panoptic_seg`. And the
  71. ``pred_instances`` usually contains following keys.
  72. - scores (Tensor): Classification scores, has a shape
  73. (num_instance, )
  74. - labels (Tensor): Labels of bboxes, has a shape
  75. (num_instances, ).
  76. - bboxes (Tensor): Has a shape (num_instances, 4),
  77. the last dimension 4 arrange as (x1, y1, x2, y2).
  78. - masks (Tensor): Has a shape (num_instances, H, W).
  79. And the ``pred_panoptic_seg`` contains the following key
  80. - sem_seg (Tensor): panoptic segmentation mask, has a
  81. shape (1, h, w).
  82. """
  83. feats = self.extract_feat(batch_inputs)
  84. mask_cls_results, mask_pred_results = self.panoptic_head.predict(
  85. feats, batch_data_samples)
  86. results_list = self.panoptic_fusion_head.predict(
  87. mask_cls_results,
  88. mask_pred_results,
  89. batch_data_samples,
  90. rescale=rescale)
  91. results = self.add_pred_to_datasample(batch_data_samples, results_list)
  92. return results
  93. def add_pred_to_datasample(self, data_samples: SampleList,
  94. results_list: List[dict]) -> SampleList:
  95. """Add predictions to `DetDataSample`.
  96. Args:
  97. data_samples (list[:obj:`DetDataSample`], optional): A batch of
  98. data samples that contain annotations and predictions.
  99. results_list (List[dict]): Instance segmentation, segmantic
  100. segmentation and panoptic segmentation results.
  101. Returns:
  102. list[:obj:`DetDataSample`]: Detection results of the
  103. input images. Each DetDataSample usually contain
  104. 'pred_instances' and `pred_panoptic_seg`. And the
  105. ``pred_instances`` usually contains following keys.
  106. - scores (Tensor): Classification scores, has a shape
  107. (num_instance, )
  108. - labels (Tensor): Labels of bboxes, has a shape
  109. (num_instances, ).
  110. - bboxes (Tensor): Has a shape (num_instances, 4),
  111. the last dimension 4 arrange as (x1, y1, x2, y2).
  112. - masks (Tensor): Has a shape (num_instances, H, W).
  113. And the ``pred_panoptic_seg`` contains the following key
  114. - sem_seg (Tensor): panoptic segmentation mask, has a
  115. shape (1, h, w).
  116. """
  117. for data_sample, pred_results in zip(data_samples, results_list):
  118. if 'pan_results' in pred_results:
  119. data_sample.pred_panoptic_seg = pred_results['pan_results']
  120. if 'ins_results' in pred_results:
  121. data_sample.pred_instances = pred_results['ins_results']
  122. assert 'sem_results' not in pred_results, 'segmantic ' \
  123. 'segmentation results are not supported yet.'
  124. return data_samples
  125. def _forward(self, batch_inputs: Tensor,
  126. batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
  127. """Network forward process. Usually includes backbone, neck and head
  128. forward without any post-processing.
  129. Args:
  130. batch_inputs (Tensor): Inputs with shape (N, C, H, W).
  131. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  132. data samples. It usually includes information such
  133. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  134. Returns:
  135. tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
  136. forward.
  137. """
  138. feats = self.extract_feat(batch_inputs)
  139. results = self.panoptic_head.forward(feats, batch_data_samples)
  140. return results