# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import ModuleList
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from ..layers import ConvUpsample
from ..utils import interpolate_as
from .base_semantic_head import BaseSemanticHead


@MODELS.register_module()
class PanopticFPNHead(BaseSemanticHead):
- """PanopticFPNHead used in Panoptic FPN.
- In this head, the number of output channels is ``num_stuff_classes
- + 1``, including all stuff classes and one thing class. The stuff
- classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
- thing classes will be merged to ``num_stuff_classes``-th channel.
- Arg:
- num_things_classes (int): Number of thing classes. Default: 80.
- num_stuff_classes (int): Number of stuff classes. Default: 53.
- in_channels (int): Number of channels in the input feature
- map.
- inner_channels (int): Number of channels in inner features.
- start_level (int): The start level of the input features
- used in PanopticFPN.
- end_level (int): The end level of the used features, the
- ``end_level``-th layer will not be used.
- conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
- and config conv layer.
- norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
- norm layer. Use ``GN`` by default.
- init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
- dict.
- loss_seg (Union[ConfigDict, dict]): the loss of the semantic head.
- """
    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 in_channels: int = 256,
                 inner_channels: int = 128,
                 start_level: int = 0,
                 end_level: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        seg_rescale_factor = 1 / 2**(start_level + 2)
        super().__init__(
            num_classes=num_stuff_classes + 1,
            seg_rescale_factor=seg_rescale_factor,
            loss_seg=loss_seg,
            init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        # Used feature layers are [start_level, end_level)
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        self.conv_upsample_layers = ModuleList()
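        # One subnet per used FPN level: the subnet for level ``i`` stacks
        # ``i`` conv blocks and ``i`` 2x upsamples (level 0 keeps its
        # resolution and gets a single conv), so stride ``2**(i + 2)``
        # features all come out at the same spatial size and can be summed
        # in ``forward()``.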
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)
    def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
        """Merge thing classes to one class.

        In PanopticFPN, the stuff (background) labels will be shifted into
        the range ``[0, self.num_stuff_classes - 1]``, and all thing
        (foreground) labels will be collapsed into the single label
        ``self.num_stuff_classes``.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        # Pixels labeled with a thing class.
        fg_mask = gt_semantic_seg < self.num_things_classes
        # Pixels labeled with a stuff class.
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        # Shift stuff labels into [0, num_stuff_classes).
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        # Collapse every thing label into the single foreground label.
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg
    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """
        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                ``gt_instances``, ``gt_panoptic_seg`` or ``gt_sem_seg``.

        Returns:
            Dict[str, Tensor]: The loss of semantic head.
        """
        seg_preds = self(x)['seg_preds']
        gt_semantic_segs = [
            data_sample.gt_sem_seg.sem_seg
            for data_sample in batch_data_samples
        ]
        gt_semantic_segs = torch.stack(gt_semantic_segs)
        if self.seg_rescale_factor != 1.0:
            gt_semantic_segs = F.interpolate(
                gt_semantic_segs.float(),
                scale_factor=self.seg_rescale_factor,
                mode='nearest').squeeze(1)

        # Things classes will be merged to one class in PanopticFPN.
        gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)

        if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_segs.reshape(-1).long())

        return dict(loss_seg=loss_seg)
    def init_weights(self) -> None:
        """Initialize weights."""
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()
    def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
        """Forward.

        Args:
            x (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            dict[str, Tensor]: Semantic segmentation predictions and
            feature maps.
        """
        # The number of subnets must not exceed the number of input
        # feature maps.
        assert self.num_stages <= len(x)
        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        # All per-level features now share one spatial size; fuse them by
        # element-wise summation before the 1x1 classification conv.
        seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(seg_feats)
        out = dict(seg_preds=seg_preds, seg_feats=seg_feats)
        return out
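

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch, not part of the upstream module: it assumes a
# working mmdet/mmengine installation and fabricates random FPN features for
# a batch of two images, to show how four levels (strides 4/8/16/32) are
# fused into a single ``num_stuff_classes + 1``-channel logit map.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    head = PanopticFPNHead(
        num_things_classes=80,
        num_stuff_classes=53,
        in_channels=256,
        inner_channels=128,
        start_level=0,
        end_level=4)
    # Fake FPN outputs for a 256x256 image: level i has stride 2**(i + 2).
    feats = tuple(
        torch.rand(2, 256, 64 // 2**i, 64 // 2**i) for i in range(4))
    out = head(feats)
    # Every level is brought back to the stride-4 resolution and summed.
    assert out['seg_preds'].shape == (2, 54, 64, 64)  # 53 stuff + 1 thing
    assert out['seg_feats'].shape == (2, 128, 64, 64)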
|