fused_semantic_head.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import Optional, Tuple

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType
  12. @MODELS.register_module()
  13. class FusedSemanticHead(BaseModule):
  14. r"""Multi-level fused semantic segmentation head.
  15. .. code-block:: none
  16. in_1 -> 1x1 conv ---
  17. |
  18. in_2 -> 1x1 conv -- |
  19. ||
  20. in_3 -> 1x1 conv - ||
  21. ||| /-> 1x1 conv (mask prediction)
  22. in_4 -> 1x1 conv -----> 3x3 convs (*4)
  23. | \-> 1x1 conv (feature)
  24. in_5 -> 1x1 conv ---
  25. """ # noqa: W605
  26. def __init__(
  27. self,
  28. num_ins: int,
  29. fusion_level: int,
  30. seg_scale_factor=1 / 8,
  31. num_convs: int = 4,
  32. in_channels: int = 256,
  33. conv_out_channels: int = 256,
  34. num_classes: int = 183,
  35. conv_cfg: OptConfigType = None,
  36. norm_cfg: OptConfigType = None,
  37. ignore_label: int = None,
  38. loss_weight: float = None,
  39. loss_seg: ConfigDict = dict(
  40. type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2),
  41. init_cfg: MultiConfig = dict(
  42. type='Kaiming', override=dict(name='conv_logits'))
  43. ) -> None:
  44. super().__init__(init_cfg=init_cfg)
  45. self.num_ins = num_ins
  46. self.fusion_level = fusion_level
  47. self.seg_scale_factor = seg_scale_factor
  48. self.num_convs = num_convs
  49. self.in_channels = in_channels
  50. self.conv_out_channels = conv_out_channels
  51. self.num_classes = num_classes
  52. self.conv_cfg = conv_cfg
  53. self.norm_cfg = norm_cfg
  54. self.fp16_enabled = False
  55. self.lateral_convs = nn.ModuleList()
  56. for i in range(self.num_ins):
  57. self.lateral_convs.append(
  58. ConvModule(
  59. self.in_channels,
  60. self.in_channels,
  61. 1,
  62. conv_cfg=self.conv_cfg,
  63. norm_cfg=self.norm_cfg,
  64. inplace=False))
  65. self.convs = nn.ModuleList()
  66. for i in range(self.num_convs):
  67. in_channels = self.in_channels if i == 0 else conv_out_channels
  68. self.convs.append(
  69. ConvModule(
  70. in_channels,
  71. conv_out_channels,
  72. 3,
  73. padding=1,
  74. conv_cfg=self.conv_cfg,
  75. norm_cfg=self.norm_cfg))
  76. self.conv_embedding = ConvModule(
  77. conv_out_channels,
  78. conv_out_channels,
  79. 1,
  80. conv_cfg=self.conv_cfg,
  81. norm_cfg=self.norm_cfg)
  82. self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
  83. if ignore_label:
  84. loss_seg['ignore_index'] = ignore_label
  85. if loss_weight:
  86. loss_seg['loss_weight'] = loss_weight
  87. if ignore_label or loss_weight:
  88. warnings.warn('``ignore_label`` and ``loss_weight`` would be '
  89. 'deprecated soon. Please set ``ingore_index`` and '
  90. '``loss_weight`` in ``loss_seg`` instead.')
  91. self.criterion = MODELS.build(loss_seg)
  92. def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
  93. """Forward function.
  94. Args:
  95. feats (tuple[Tensor]): Multi scale feature maps.
  96. Returns:
  97. tuple[Tensor]:
  98. - mask_preds (Tensor): Predicted mask logits.
  99. - x (Tensor): Fused feature.
  100. """
  101. x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
  102. fused_size = tuple(x.shape[-2:])
  103. for i, feat in enumerate(feats):
  104. if i != self.fusion_level:
  105. feat = F.interpolate(
  106. feat, size=fused_size, mode='bilinear', align_corners=True)
  107. # fix runtime error of "+=" inplace operation in PyTorch 1.10
  108. x = x + self.lateral_convs[i](feat)
  109. for i in range(self.num_convs):
  110. x = self.convs[i](x)
  111. mask_preds = self.conv_logits(x)
  112. x = self.conv_embedding(x)
  113. return mask_preds, x
  114. def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor:
  115. """Loss function.
  116. Args:
  117. mask_preds (Tensor): Predicted mask logits.
  118. labels (Tensor): Ground truth.
  119. Returns:
  120. Tensor: Semantic segmentation loss.
  121. """
  122. labels = F.interpolate(
  123. labels.float(), scale_factor=self.seg_scale_factor, mode='nearest')
  124. labels = labels.squeeze(1).long()
  125. loss_semantic_seg = self.criterion(mask_preds, labels)
  126. return loss_semantic_seg