sparseinst.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
  2. from typing import List, Tuple, Union
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.models import BaseDetector
  8. from mmdet.models.utils import unpack_gt_instances
  9. from mmdet.registry import MODELS
  10. from mmdet.structures import OptSampleList, SampleList
  11. from mmdet.utils import ConfigType, OptConfigType
  12. @torch.jit.script
  13. def rescoring_mask(scores, mask_pred, masks):
  14. mask_pred_ = mask_pred.float()
  15. return scores * ((masks * mask_pred_).sum([1, 2]) /
  16. (mask_pred_.sum([1, 2]) + 1e-6))
  17. @MODELS.register_module()
  18. class SparseInst(BaseDetector):
  19. """Implementation of `SparseInst <https://arxiv.org/abs/1912.02424>`_
  20. Args:
  21. data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
  22. :class:`DetDataPreprocessor` to process the input data.
  23. Defaults to None.
  24. backbone (:obj:`ConfigDict` or dict): The backbone module.
  25. encoder (:obj:`ConfigDict` or dict): The encoder module.
  26. decoder (:obj:`ConfigDict` or dict): The decoder module.
  27. criterion (:obj:`ConfigDict` or dict, optional): The training matcher
  28. and losses. Defaults to None.
  29. test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
  30. of SparseInst. Defaults to None.
  31. init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
  32. the initialization. Defaults to None.
  33. """
  34. def __init__(self,
  35. data_preprocessor: ConfigType,
  36. backbone: ConfigType,
  37. encoder: ConfigType,
  38. decoder: ConfigType,
  39. criterion: OptConfigType = None,
  40. test_cfg: OptConfigType = None,
  41. init_cfg: OptConfigType = None):
  42. super().__init__(
  43. data_preprocessor=data_preprocessor, init_cfg=init_cfg)
  44. # backbone
  45. self.backbone = MODELS.build(backbone)
  46. # encoder & decoder
  47. self.encoder = MODELS.build(encoder)
  48. self.decoder = MODELS.build(decoder)
  49. # matcher & loss (matcher is built in loss)
  50. self.criterion = MODELS.build(criterion)
  51. # inference
  52. self.cls_threshold = test_cfg.score_thr
  53. self.mask_threshold = test_cfg.mask_thr_binary
  54. def _forward(
  55. self,
  56. batch_inputs: Tensor,
  57. batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
  58. """Network forward process. Usually includes backbone, neck and head
  59. forward without any post-processing.
  60. Args:
  61. batch_inputs (Tensor): Inputs with shape (N, C, H, W).
  62. Returns:
  63. tuple[list]: A tuple of features from ``bbox_head`` forward.
  64. """
  65. x = self.backbone(batch_inputs)
  66. x = self.encoder(x)
  67. results = self.decoder(x)
  68. return results
  69. def predict(self,
  70. batch_inputs: Tensor,
  71. batch_data_samples: SampleList,
  72. rescale: bool = True) -> SampleList:
  73. """Predict results from a batch of inputs and data samples with post-
  74. processing.
  75. Args:
  76. batch_inputs (Tensor): Inputs with shape (N, C, H, W).
  77. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  78. Samples. It usually includes information such as
  79. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  80. rescale (bool): Whether to rescale the results.
  81. Defaults to True.
  82. Returns:
  83. list[:obj:`DetDataSample`]: Detection results of the
  84. input images. Each DetDataSample usually contain
  85. 'pred_instances'. And the ``pred_instances`` usually
  86. contains following keys.
  87. - scores (Tensor): Classification scores, has a shape
  88. (num_instance, )
  89. - labels (Tensor): Labels of bboxes, has a shape
  90. (num_instances, ).
  91. - bboxes (Tensor): Has a shape (num_instances, 4),
  92. the last dimension 4 arrange as (x1, y1, x2, y2).
  93. """
  94. max_shape = batch_inputs.shape[-2:]
  95. output = self._forward(batch_inputs)
  96. pred_scores = output['pred_logits'].sigmoid()
  97. pred_masks = output['pred_masks'].sigmoid()
  98. pred_objectness = output['pred_scores'].sigmoid()
  99. pred_scores = torch.sqrt(pred_scores * pred_objectness)
  100. results_list = []
  101. for batch_idx, (scores_per_image, mask_pred_per_image,
  102. datasample) in enumerate(
  103. zip(pred_scores, pred_masks, batch_data_samples)):
  104. result = InstanceData()
  105. # max/argmax
  106. scores, labels = scores_per_image.max(dim=-1)
  107. # cls threshold
  108. keep = scores > self.cls_threshold
  109. scores = scores[keep]
  110. labels = labels[keep]
  111. mask_pred_per_image = mask_pred_per_image[keep]
  112. if scores.size(0) == 0:
  113. result.scores = scores
  114. result.labels = labels
  115. results_list.append(result)
  116. continue
  117. img_meta = datasample.metainfo
  118. # rescoring mask using maskness
  119. scores = rescoring_mask(scores,
  120. mask_pred_per_image > self.mask_threshold,
  121. mask_pred_per_image)
  122. h, w = img_meta['img_shape'][:2]
  123. mask_pred_per_image = F.interpolate(
  124. mask_pred_per_image.unsqueeze(1),
  125. size=max_shape,
  126. mode='bilinear',
  127. align_corners=False)[:, :, :h, :w]
  128. if rescale:
  129. ori_h, ori_w = img_meta['ori_shape'][:2]
  130. mask_pred_per_image = F.interpolate(
  131. mask_pred_per_image,
  132. size=(ori_h, ori_w),
  133. mode='bilinear',
  134. align_corners=False).squeeze(1)
  135. mask_pred = mask_pred_per_image > self.mask_threshold
  136. result.masks = mask_pred
  137. result.scores = scores
  138. result.labels = labels
  139. # create an empty bbox in InstanceData to avoid bugs when
  140. # calculating metrics.
  141. result.bboxes = result.scores.new_zeros(len(scores), 4)
  142. results_list.append(result)
  143. batch_data_samples = self.add_pred_to_datasample(
  144. batch_data_samples, results_list)
  145. return batch_data_samples
  146. def loss(self, batch_inputs: Tensor,
  147. batch_data_samples: SampleList) -> Union[dict, list]:
  148. """Calculate losses from a batch of inputs and data samples.
  149. Args:
  150. batch_inputs (Tensor): Input images of shape (N, C, H, W).
  151. These should usually be mean centered and std scaled.
  152. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  153. data samples. It usually includes information such
  154. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  155. Returns:
  156. dict: A dictionary of loss components.
  157. """
  158. outs = self._forward(batch_inputs)
  159. (batch_gt_instances, batch_gt_instances_ignore,
  160. batch_img_metas) = unpack_gt_instances(batch_data_samples)
  161. losses = self.criterion(outs, batch_gt_instances, batch_img_metas,
  162. batch_gt_instances_ignore)
  163. return losses
  164. def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
  165. """Extract features.
  166. Args:
  167. batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).
  168. Returns:
  169. tuple[Tensor]: Multi-level features that may have
  170. different resolutions.
  171. """
  172. x = self.backbone(batch_inputs)
  173. x = self.encoder(x)
  174. return x