123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from abc import ABCMeta, abstractmethod
- from typing import List, Tuple, Union
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.structures import SampleList
- from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
- from ..utils import unpack_gt_instances
class BaseMaskHead(BaseModule, metaclass=ABCMeta):
    """Base class for mask heads used in One-Stage Instance Segmentation.

    Subclasses implement the forward pass plus :meth:`loss_by_feat` and
    :meth:`predict_by_feat`; this base class wires them together in
    :meth:`loss` (training) and :meth:`predict` (inference).

    Args:
        init_cfg (dict or list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self, init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def loss_by_feat(self, *args, **kwargs):
        """Calculate the loss based on the features extracted by the mask
        head."""

    @abstractmethod
    def predict_by_feat(self, *args, **kwargs):
        """Transform a batch of output features extracted from the head into
        mask results."""

    def _forward_and_check(self, x, extra_infos=None) -> tuple:
        """Call the head on ``x`` and check that it returned a tuple.

        ``extra_infos`` is passed to the call only when it is not None, so
        heads that assign labels themselves are called as ``self(x)`` while
        heads fed by an upstream bbox head are called as
        ``self(x, extra_infos)``.

        Args:
            x: Features from the upstream network.
            extra_infos (list[:obj:`InstanceData`], optional): Positive
                sample infos (training) or detection results (inference)
                produced outside this head. Defaults to None.

        Returns:
            tuple: The raw forward results, later star-unpacked into
            ``loss_by_feat`` / ``predict_by_feat``.

        Raises:
            AssertionError: If the forward results are not a tuple. Callers
                star-unpack the result, so a bare Tensor would otherwise be
                silently split along its first dimension.
        """
        if extra_infos is None:
            outs = self(x)
        else:
            outs = self(x, extra_infos)
        assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
            'even if only one item is returned'
        return outs

    def loss(self,
             x: Union[List[Tensor], Tuple[Tensor]],
             batch_data_samples: SampleList,
             positive_infos: OptInstanceList = None,
             **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the mask head on
        the features of the upstream network.

        Args:
            x (list[Tensor] | tuple[Tensor]): Features from FPN.
                Each has a shape (B, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
                the meta information of each image and corresponding
                annotations.
            positive_infos (list[:obj:`InstanceData`], optional): Information
                of positive samples. Used when the label assignment is
                done outside the MaskHead, e.g., BboxHead in
                YOLACT or CondInst, etc. When the label assignment is done in
                MaskHead, it would be None, like SOLO or SOLOv2. All values
                in it should have shape (num_positive_samples, *).

        Returns:
            dict: A dictionary of loss components.

        Note:
            The ``masks`` of every GT instance are replaced in place by a
            copy padded to the image's ``batch_input_shape``.
        """
        outs = self._forward_and_check(x, positive_infos)

        batch_gt_instances, batch_gt_instances_ignore, batch_img_metas = (
            unpack_gt_instances(batch_data_samples))

        # Pad each image's GT masks to the padded batch input shape,
        # presumably so mask targets share the spatial frame of the padded
        # batch features (confirm against the concrete loss_by_feat).
        for gt_instances, img_metas in zip(batch_gt_instances,
                                           batch_img_metas):
            gt_instances.masks = gt_instances.masks.pad(
                img_metas['batch_input_shape'])

        return self.loss_by_feat(
            *outs,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            positive_infos=positive_infos,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            **kwargs)

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False,
                results_list: OptInstanceList = None,
                **kwargs) -> InstanceList:
        """Test function without test-time augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            results_list (list[:obj:`InstanceData`], optional): Detection
                results of each image after the post process. Only exist
                if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.

        Returns:
            list[:obj:`InstanceData`]: Instance segmentation
            results of each image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance,)
                - labels (Tensor): Has a shape (num_instances,).
                - masks (Tensor): Processed mask results, has a
                  shape (num_instances, h, w).
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        # Same tuple contract as ``loss``: the results are star-unpacked below.
        outs = self._forward_and_check(x, results_list)

        return self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            rescale=rescale,
            results_list=results_list,
            **kwargs)
|