base_semantic_head.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from typing import Dict, List, Tuple, Union
  4. import torch.nn.functional as F
  5. from mmengine.model import BaseModule
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import SampleList
  9. from mmdet.utils import ConfigType, OptMultiConfig
  10. @MODELS.register_module()
  11. class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
  12. """Base module of Semantic Head.
  13. Args:
  14. num_classes (int): the number of classes.
  15. seg_rescale_factor (float): the rescale factor for ``gt_sem_seg``,
  16. which equals to ``1 / output_strides``. The output_strides is
  17. for ``seg_preds``. Defaults to 1 / 4.
  18. init_cfg (Optional[Union[:obj:`ConfigDict`, dict]]): the initialization
  19. config.
  20. loss_seg (Union[:obj:`ConfigDict`, dict]): the loss of the semantic
  21. head.
  22. """
  23. def __init__(self,
  24. num_classes: int,
  25. seg_rescale_factor: float = 1 / 4.,
  26. loss_seg: ConfigType = dict(
  27. type='CrossEntropyLoss',
  28. ignore_index=255,
  29. loss_weight=1.0),
  30. init_cfg: OptMultiConfig = None) -> None:
  31. super().__init__(init_cfg=init_cfg)
  32. self.loss_seg = MODELS.build(loss_seg)
  33. self.num_classes = num_classes
  34. self.seg_rescale_factor = seg_rescale_factor
  35. @abstractmethod
  36. def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Dict[str, Tensor]:
  37. """Placeholder of forward function.
  38. Args:
  39. x (Tensor): Feature maps.
  40. Returns:
  41. Dict[str, Tensor]: A dictionary, including features
  42. and predicted scores. Required keys: 'seg_preds'
  43. and 'feats'.
  44. """
  45. pass
  46. @abstractmethod
  47. def loss(self, x: Union[Tensor, Tuple[Tensor]],
  48. batch_data_samples: SampleList) -> Dict[str, Tensor]:
  49. """
  50. Args:
  51. x (Union[Tensor, Tuple[Tensor]]): Feature maps.
  52. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  53. data samples. It usually includes information such
  54. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  55. Args:
  56. x (Tensor): Feature maps.
  57. Returns:
  58. Dict[str, Tensor]: The loss of semantic head.
  59. """
  60. pass
  61. def predict(self,
  62. x: Union[Tensor, Tuple[Tensor]],
  63. batch_img_metas: List[dict],
  64. rescale: bool = False) -> List[Tensor]:
  65. """Test without Augmentation.
  66. Args:
  67. x (Union[Tensor, Tuple[Tensor]]): Feature maps.
  68. batch_img_metas (List[dict]): List of image information.
  69. rescale (bool): Whether to rescale the results.
  70. Defaults to False.
  71. Returns:
  72. list[Tensor]: semantic segmentation logits.
  73. """
  74. seg_preds = self.forward(x)['seg_preds']
  75. seg_preds = F.interpolate(
  76. seg_preds,
  77. size=batch_img_metas[0]['batch_input_shape'],
  78. mode='bilinear',
  79. align_corners=False)
  80. seg_preds = [seg_preds[i] for i in range(len(batch_img_metas))]
  81. if rescale:
  82. seg_pred_list = []
  83. for i in range(len(batch_img_metas)):
  84. h, w = batch_img_metas[i]['img_shape']
  85. seg_pred = seg_preds[i][:, :h, :w]
  86. h, w = batch_img_metas[i]['ori_shape']
  87. seg_pred = F.interpolate(
  88. seg_pred[None],
  89. size=(h, w),
  90. mode='bilinear',
  91. align_corners=False)[0]
  92. seg_pred_list.append(seg_pred)
  93. else:
  94. seg_pred_list = seg_preds
  95. return seg_pred_list