base_roi_head.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import Tuple
  4. from mmengine.model import BaseModule
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from mmdet.structures import SampleList
  8. from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
  9. class BaseRoIHead(BaseModule, metaclass=ABCMeta):
  10. """Base class for RoIHeads."""
  11. def __init__(self,
  12. bbox_roi_extractor: OptMultiConfig = None,
  13. bbox_head: OptMultiConfig = None,
  14. mask_roi_extractor: OptMultiConfig = None,
  15. mask_head: OptMultiConfig = None,
  16. shared_head: OptConfigType = None,
  17. train_cfg: OptConfigType = None,
  18. test_cfg: OptConfigType = None,
  19. init_cfg: OptMultiConfig = None) -> None:
  20. super().__init__(init_cfg=init_cfg)
  21. self.train_cfg = train_cfg
  22. self.test_cfg = test_cfg
  23. if shared_head is not None:
  24. self.shared_head = MODELS.build(shared_head)
  25. if bbox_head is not None:
  26. self.init_bbox_head(bbox_roi_extractor, bbox_head)
  27. if mask_head is not None:
  28. self.init_mask_head(mask_roi_extractor, mask_head)
  29. self.init_assigner_sampler()
  30. @property
  31. def with_bbox(self) -> bool:
  32. """bool: whether the RoI head contains a `bbox_head`"""
  33. return hasattr(self, 'bbox_head') and self.bbox_head is not None
  34. @property
  35. def with_mask(self) -> bool:
  36. """bool: whether the RoI head contains a `mask_head`"""
  37. return hasattr(self, 'mask_head') and self.mask_head is not None
  38. @property
  39. def with_shared_head(self) -> bool:
  40. """bool: whether the RoI head contains a `shared_head`"""
  41. return hasattr(self, 'shared_head') and self.shared_head is not None
  42. @abstractmethod
  43. def init_bbox_head(self, *args, **kwargs):
  44. """Initialize ``bbox_head``"""
  45. pass
  46. @abstractmethod
  47. def init_mask_head(self, *args, **kwargs):
  48. """Initialize ``mask_head``"""
  49. pass
  50. @abstractmethod
  51. def init_assigner_sampler(self, *args, **kwargs):
  52. """Initialize assigner and sampler."""
  53. pass
  54. @abstractmethod
  55. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  56. batch_data_samples: SampleList):
  57. """Perform forward propagation and loss calculation of the roi head on
  58. the features of the upstream network."""
  59. def predict(self,
  60. x: Tuple[Tensor],
  61. rpn_results_list: InstanceList,
  62. batch_data_samples: SampleList,
  63. rescale: bool = False) -> InstanceList:
  64. """Perform forward propagation of the roi head and predict detection
  65. results on the features of the upstream network.
  66. Args:
  67. x (tuple[Tensor]): Features from upstream network. Each
  68. has shape (N, C, H, W).
  69. rpn_results_list (list[:obj:`InstanceData`]): list of region
  70. proposals.
  71. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  72. Samples. It usually includes information such as
  73. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  74. rescale (bool): Whether to rescale the results to
  75. the original image. Defaults to True.
  76. Returns:
  77. list[obj:`InstanceData`]: Detection results of each image.
  78. Each item usually contains following keys.
  79. - scores (Tensor): Classification scores, has a shape
  80. (num_instance, )
  81. - labels (Tensor): Labels of bboxes, has a shape
  82. (num_instances, ).
  83. - bboxes (Tensor): Has a shape (num_instances, 4),
  84. the last dimension 4 arrange as (x1, y1, x2, y2).
  85. - masks (Tensor): Has a shape (num_instances, H, W).
  86. """
  87. assert self.with_bbox, 'Bbox head must be implemented.'
  88. batch_img_metas = [
  89. data_samples.metainfo for data_samples in batch_data_samples
  90. ]
  91. # TODO: nms_op in mmcv need be enhanced, the bbox result may get
  92. # difference when not rescale in bbox_head
  93. # If it has the mask branch, the bbox branch does not need
  94. # to be scaled to the original image scale, because the mask
  95. # branch will scale both bbox and mask at the same time.
  96. bbox_rescale = rescale if not self.with_mask else False
  97. results_list = self.predict_bbox(
  98. x,
  99. batch_img_metas,
  100. rpn_results_list,
  101. rcnn_test_cfg=self.test_cfg,
  102. rescale=bbox_rescale)
  103. if self.with_mask:
  104. results_list = self.predict_mask(
  105. x, batch_img_metas, results_list, rescale=rescale)
  106. return results_list