123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- from typing import Tuple
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmengine.config import ConfigDict
- from mmengine.model import BaseModule
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.utils import MultiConfig, OptConfigType
@MODELS.register_module()
class FusedSemanticHead(BaseModule):
    r"""Multi-level fused semantic segmentation head.

    Every input level goes through its own 1x1 lateral conv. All levels
    other than ``fusion_level`` are first resized to the spatial size of
    the ``fusion_level`` feature, and the results are summed. The fused map
    is refined by a stack of 3x3 convs and then feeds two 1x1 branches: one
    producing per-class mask logits, the other producing an embedding
    feature.

    .. code-block:: none

        in_1 -> 1x1 conv ---
                            |
        in_2 -> 1x1 conv -- |
                           ||
        in_3 -> 1x1 conv - ||
                          |||                  /-> 1x1 conv (mask prediction)
        in_4 -> 1x1 conv -----> 3x3 convs (*4)
                            |                  \-> 1x1 conv (feature)
        in_5 -> 1x1 conv ---

    Args:
        num_ins (int): Number of input feature levels; one lateral 1x1 conv
            is created per level.
        fusion_level (int): Index of the input level whose spatial size all
            other levels are resized to before summation.
        seg_scale_factor (float): Scale factor applied to the ground-truth
            labels in :meth:`loss` before computing the loss.
            Defaults to 1 / 8.
        num_convs (int): Number of 3x3 convs applied to the fused feature.
            Defaults to 4.
        in_channels (int): Channels of every input feature map (and of the
            lateral conv outputs). Defaults to 256.
        conv_out_channels (int): Output channels of the 3x3 convs and of
            the embedding conv. Defaults to 256.
        num_classes (int): Number of output channels of the logits conv.
            Defaults to 183.
        conv_cfg (:obj:`ConfigDict` or dict, optional): ``conv_cfg``
            forwarded to every ``ConvModule``. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): ``norm_cfg``
            forwarded to every ``ConvModule``. Defaults to None.
        ignore_label (int, optional): Deprecated. When not None, overrides
            ``loss_seg['ignore_index']``. Defaults to None.
        loss_weight (float, optional): Deprecated. When not None, overrides
            ``loss_seg['loss_weight']``. Defaults to None.
        loss_seg (:obj:`ConfigDict` or dict): Config used to build the
            segmentation loss through ``MODELS.build``.
        init_cfg (:obj:`ConfigDict` or dict or list): Initialization
            config passed to ``BaseModule``.
    """  # noqa: W605

    def __init__(
        self,
        num_ins: int,
        fusion_level: int,
        seg_scale_factor: float = 1 / 8,
        num_convs: int = 4,
        in_channels: int = 256,
        conv_out_channels: int = 256,
        num_classes: int = 183,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        ignore_label: int = None,
        loss_weight: float = None,
        loss_seg: ConfigDict = dict(
            type='CrossEntropyLoss', ignore_index=255, loss_weight=0.2),
        init_cfg: MultiConfig = dict(
            type='Kaiming', override=dict(name='conv_logits'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_ins = num_ins
        self.fusion_level = fusion_level
        self.seg_scale_factor = seg_scale_factor
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        # One 1x1 lateral conv per input level (channel count preserved).
        self.lateral_convs = nn.ModuleList()
        for _ in range(self.num_ins):
            self.lateral_convs.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False))

        # Stack of 3x3 convs applied to the fused feature map. Only the
        # first conv consumes ``in_channels``; the rest chain on
        # ``conv_out_channels``.
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            conv_in_channels = (
                self.in_channels if i == 0 else conv_out_channels)
            self.convs.append(
                ConvModule(
                    conv_in_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        # Two 1x1 output branches: embedding feature and class logits.
        self.conv_embedding = ConvModule(
            conv_out_channels,
            conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)

        # Deprecated overrides. Compare against None (not truthiness) so
        # that legitimate values such as ``ignore_label=0`` or
        # ``loss_weight=0.0`` are honoured.
        if ignore_label is not None or loss_weight is not None:
            # Work on a copy: ``loss_seg`` may be the shared default dict
            # (or a config object owned by the caller), and writing into it
            # would leak the override into every later instantiation.
            loss_seg = loss_seg.copy()
            if ignore_label is not None:
                loss_seg['ignore_index'] = ignore_label
            if loss_weight is not None:
                loss_seg['loss_weight'] = loss_weight
            warnings.warn('``ignore_label`` and ``loss_weight`` would be '
                          'deprecated soon. Please set ``ignore_index`` and '
                          '``loss_weight`` in ``loss_seg`` instead.')
        self.criterion = MODELS.build(loss_seg)

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            feats (tuple[Tensor]): Multi scale feature maps.

        Returns:
            tuple[Tensor]:

                - mask_preds (Tensor): Predicted mask logits.
                - x (Tensor): Fused feature.
        """
        # The fusion level defines the target spatial size.
        x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
        fused_size = tuple(x.shape[-2:])
        for i, feat in enumerate(feats):
            if i != self.fusion_level:
                feat = F.interpolate(
                    feat, size=fused_size, mode='bilinear', align_corners=True)
                # fix runtime error of "+=" inplace operation in PyTorch 1.10
                x = x + self.lateral_convs[i](feat)

        for conv in self.convs:
            x = conv(x)

        # Logits are taken from the refined feature *before* the embedding
        # conv is applied to it.
        mask_preds = self.conv_logits(x)
        x = self.conv_embedding(x)
        return mask_preds, x

    def loss(self, mask_preds: Tensor, labels: Tensor) -> Tensor:
        """Loss function.

        Args:
            mask_preds (Tensor): Predicted mask logits.
            labels (Tensor): Ground truth. It is resized by
                ``seg_scale_factor`` with nearest-neighbour interpolation
                and its dim 1 is squeezed, so it is presumably shaped
                (N, 1, H, W) -- confirm against the caller.

        Returns:
            Tensor: Semantic segmentation loss.
        """
        # ``F.interpolate`` needs a floating tensor; 'nearest' keeps the
        # values as valid class ids, which are cast back to long below.
        labels = F.interpolate(
            labels.float(), scale_factor=self.seg_scale_factor, mode='nearest')
        labels = labels.squeeze(1).long()
        loss_semantic_seg = self.criterion(mask_preds, labels)
        return loss_semantic_seg
|