123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- # Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
- from typing import List, Tuple, Union
- import torch
- import torch.nn.functional as F
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.models import BaseDetector
- from mmdet.models.utils import unpack_gt_instances
- from mmdet.registry import MODELS
- from mmdet.structures import OptSampleList, SampleList
- from mmdet.utils import ConfigType, OptConfigType
@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
    """Scale instance scores by their mask quality ("maskness").

    Args:
        scores: Per-instance scores, one value per instance.
        mask_pred: Binarised masks (bool or numeric; cast to float here).
            Reduced over dims 1 and 2, so presumably shaped (N, H, W) --
            TODO confirm against callers other than ``SparseInst.predict``.
        masks: Soft mask probabilities, same shape as ``mask_pred``.

    Returns:
        ``scores`` multiplied by the mean of ``masks`` over the foreground
        pixels of ``mask_pred``. The 1e-6 epsilon keeps the division finite
        when a binarised mask is empty; the factor is then 0.
    """
    fg = mask_pred.float()
    # Sum of soft probabilities that fall inside the binarised foreground.
    fg_prob_sum = (masks * fg).sum([1, 2])
    # Foreground area, nudged away from zero.
    fg_area = fg.sum([1, 2]) + 1e-6
    return scores * (fg_prob_sum / fg_area)
@MODELS.register_module()
class SparseInst(BaseDetector):
    """Implementation of `SparseInst <https://arxiv.org/abs/2203.12827>`_.

    The model runs ``backbone -> encoder -> decoder``. The decoder output is
    read as a dict with ``pred_logits``, ``pred_masks`` and ``pred_scores``
    entries (see :meth:`predict`).

    Args:
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        backbone (:obj:`ConfigDict` or dict): The backbone module.
        encoder (:obj:`ConfigDict` or dict): The encoder module.
        decoder (:obj:`ConfigDict` or dict): The decoder module.
        criterion (:obj:`ConfigDict` or dict, optional): The training matcher
            and losses. Only needed by :meth:`loss`; may be omitted for an
            inference-only model. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict): The testing config of
            SparseInst. Must provide ``score_thr`` (class-score threshold)
            and ``mask_thr_binary`` (mask binarisation threshold).
            Defaults to None, which raises a ``ValueError``.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: ConfigType,
                 backbone: ConfigType,
                 encoder: ConfigType,
                 decoder: ConfigType,
                 criterion: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        # Fail early with a clear message instead of an AttributeError on
        # ``None`` after the (possibly expensive) sub-modules are built.
        if test_cfg is None:
            raise ValueError('SparseInst requires `test_cfg` providing '
                             '`score_thr` and `mask_thr_binary`.')
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.encoder = MODELS.build(encoder)
        self.decoder = MODELS.build(decoder)
        # The matcher is built inside the criterion. The criterion is only
        # used by ``loss``, so it is optional for inference-only models.
        self.criterion = (
            MODELS.build(criterion) if criterion is not None else None)
        # Item access works for both plain dicts and ConfigDicts.
        self.cls_threshold = test_cfg['score_thr']
        self.mask_threshold = test_cfg['mask_thr_binary']

    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None) -> dict:
        """Run backbone, encoder and decoder without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional):
                Unused here. Defaults to None.

        Returns:
            dict: Raw decoder outputs. :meth:`predict` reads the
            ``pred_logits``, ``pred_masks`` and ``pred_scores`` entries.
        """
        return self.decoder(self.extract_feat(batch_inputs))

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. Their ``metainfo`` must provide ``img_shape`` and,
                when ``rescale`` is True, ``ori_shape``.
            rescale (bool): Whether to resize masks to ``ori_shape``
                (otherwise they are returned at ``img_shape``).
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input
            images. Each ``pred_instances`` contains:

            - scores (Tensor): Scores, shape (num_instances, ).
            - labels (Tensor): Class labels, shape (num_instances, ).
            - masks (Tensor): Boolean masks, shape (num_instances, H, W),
              where (H, W) is ``ori_shape`` if ``rescale`` else
              ``img_shape``.
            - bboxes (Tensor): All-zero placeholders, shape
              (num_instances, 4).
        """
        output = self._forward(batch_inputs)
        # Final class score: geometric mean of the class probability and the
        # per-instance objectness probability.
        pred_scores = torch.sqrt(output['pred_logits'].sigmoid() *
                                 output['pred_scores'].sigmoid())
        pred_masks = output['pred_masks'].sigmoid()
        # Masks are first upsampled to the padded batch size, then cropped.
        pad_shape = batch_inputs.shape[-2:]
        results_list = [
            self._predict_single(scores, masks, data_sample.metainfo,
                                 pad_shape, rescale)
            for scores, masks, data_sample in zip(pred_scores, pred_masks,
                                                  batch_data_samples)
        ]
        return self.add_pred_to_datasample(batch_data_samples, results_list)

    def _predict_single(self, scores_per_image: Tensor,
                        masks_per_image: Tensor, img_meta: dict,
                        pad_shape: torch.Size,
                        rescale: bool) -> InstanceData:
        """Threshold, rescore and resize the predictions of one image."""
        # Keep the best class per instance and drop low-confidence ones.
        scores, labels = scores_per_image.max(dim=-1)
        keep = scores > self.cls_threshold
        scores = scores[keep]
        labels = labels[keep]
        masks = masks_per_image[keep]

        # Output resolution: original image when rescaling, otherwise the
        # unpadded network input size.
        h, w = img_meta['img_shape'][:2]
        out_h, out_w = img_meta['ori_shape'][:2] if rescale else (h, w)

        result = InstanceData()
        if scores.size(0) == 0:
            # Emit the same fields as the non-empty branch (with zero rows)
            # so images without detections do not break metric code that
            # expects ``masks``/``bboxes`` to be present.
            result.scores = scores
            result.labels = labels
            result.masks = masks.new_zeros((0, out_h, out_w),
                                           dtype=torch.bool)
            result.bboxes = scores.new_zeros((0, 4))
            return result

        # Rescore with "maskness": mean foreground probability inside the
        # binarised mask.
        scores = rescoring_mask(scores, masks > self.mask_threshold, masks)
        # Upsample to the padded input size, then crop away the padding.
        masks = F.interpolate(
            masks.unsqueeze(1),
            size=pad_shape,
            mode='bilinear',
            align_corners=False)[:, :, :h, :w]
        if rescale:
            masks = F.interpolate(
                masks,
                size=(out_h, out_w),
                mode='bilinear',
                align_corners=False)
        # Drop the channel dim in both branches (previously only done when
        # rescaling, which left (N, 1, h, w) masks for rescale=False).
        result.masks = masks.squeeze(1) > self.mask_threshold
        result.scores = scores
        result.labels = labels
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        result.bboxes = scores.new_zeros(len(scores), 4)
        return result

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components, as returned by the
            criterion.

        Raises:
            RuntimeError: If the model was built without a ``criterion``.
        """
        if self.criterion is None:
            raise RuntimeError(
                'SparseInst was built without `criterion`; '
                'losses cannot be computed.')
        outs = self._forward(batch_inputs)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = unpack_gt_instances(batch_data_samples)
        return self.criterion(outs, batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features with the backbone followed by the encoder.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Encoder output, which is what the decoder
            consumes.
        """
        return self.encoder(self.backbone(batch_inputs))
|