base_mask_head.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import List, Tuple, Union
  4. from mmengine.model import BaseModule
  5. from torch import Tensor
  6. from mmdet.structures import SampleList
  7. from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
  8. from ..utils import unpack_gt_instances
  9. class BaseMaskHead(BaseModule, metaclass=ABCMeta):
  10. """Base class for mask heads used in One-Stage Instance Segmentation."""
  11. def __init__(self, init_cfg: OptMultiConfig = None) -> None:
  12. super().__init__(init_cfg=init_cfg)
  13. @abstractmethod
  14. def loss_by_feat(self, *args, **kwargs):
  15. """Calculate the loss based on the features extracted by the mask
  16. head."""
  17. pass
  18. @abstractmethod
  19. def predict_by_feat(self, *args, **kwargs):
  20. """Transform a batch of output features extracted from the head into
  21. mask results."""
  22. pass
  23. def loss(self,
  24. x: Union[List[Tensor], Tuple[Tensor]],
  25. batch_data_samples: SampleList,
  26. positive_infos: OptInstanceList = None,
  27. **kwargs) -> dict:
  28. """Perform forward propagation and loss calculation of the mask head on
  29. the features of the upstream network.
  30. Args:
  31. x (list[Tensor] | tuple[Tensor]): Features from FPN.
  32. Each has a shape (B, C, H, W).
  33. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  34. the meta information of each image and corresponding
  35. annotations.
  36. positive_infos (list[:obj:`InstanceData`], optional): Information
  37. of positive samples. Used when the label assignment is
  38. done outside the MaskHead, e.g., BboxHead in
  39. YOLACT or CondInst, etc. When the label assignment is done in
  40. MaskHead, it would be None, like SOLO or SOLOv2. All values
  41. in it should have shape (num_positive_samples, *).
  42. Returns:
  43. dict: A dictionary of loss components.
  44. """
  45. if positive_infos is None:
  46. outs = self(x)
  47. else:
  48. outs = self(x, positive_infos)
  49. assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
  50. 'even if only one item is returned'
  51. outputs = unpack_gt_instances(batch_data_samples)
  52. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  53. = outputs
  54. for gt_instances, img_metas in zip(batch_gt_instances,
  55. batch_img_metas):
  56. img_shape = img_metas['batch_input_shape']
  57. gt_masks = gt_instances.masks.pad(img_shape)
  58. gt_instances.masks = gt_masks
  59. losses = self.loss_by_feat(
  60. *outs,
  61. batch_gt_instances=batch_gt_instances,
  62. batch_img_metas=batch_img_metas,
  63. positive_infos=positive_infos,
  64. batch_gt_instances_ignore=batch_gt_instances_ignore,
  65. **kwargs)
  66. return losses
  67. def predict(self,
  68. x: Tuple[Tensor],
  69. batch_data_samples: SampleList,
  70. rescale: bool = False,
  71. results_list: OptInstanceList = None,
  72. **kwargs) -> InstanceList:
  73. """Test function without test-time augmentation.
  74. Args:
  75. x (tuple[Tensor]): Multi-level features from the
  76. upstream network, each is a 4D-tensor.
  77. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  78. Samples. It usually includes information such as
  79. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  80. rescale (bool, optional): Whether to rescale the results.
  81. Defaults to False.
  82. results_list (list[obj:`InstanceData`], optional): Detection
  83. results of each image after the post process. Only exist
  84. if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.
  85. Returns:
  86. list[obj:`InstanceData`]: Instance segmentation
  87. results of each image after the post process.
  88. Each item usually contains following keys.
  89. - scores (Tensor): Classification scores, has a shape
  90. (num_instance,)
  91. - labels (Tensor): Has a shape (num_instances,).
  92. - masks (Tensor): Processed mask results, has a
  93. shape (num_instances, h, w).
  94. """
  95. batch_img_metas = [
  96. data_samples.metainfo for data_samples in batch_data_samples
  97. ]
  98. if results_list is None:
  99. outs = self(x)
  100. else:
  101. outs = self(x, results_list)
  102. results_list = self.predict_by_feat(
  103. *outs,
  104. batch_img_metas=batch_img_metas,
  105. rescale=rescale,
  106. results_list=results_list,
  107. **kwargs)
  108. return results_list