panoptic_fpn_head.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmengine.model import ModuleList
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.structures import SampleList
  10. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  11. from ..layers import ConvUpsample
  12. from ..utils import interpolate_as
  13. from .base_semantic_head import BaseSemanticHead
  14. @MODELS.register_module()
  15. class PanopticFPNHead(BaseSemanticHead):
  16. """PanopticFPNHead used in Panoptic FPN.
  17. In this head, the number of output channels is ``num_stuff_classes
  18. + 1``, including all stuff classes and one thing class. The stuff
  19. classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
  20. thing classes will be merged to ``num_stuff_classes``-th channel.
  21. Arg:
  22. num_things_classes (int): Number of thing classes. Default: 80.
  23. num_stuff_classes (int): Number of stuff classes. Default: 53.
  24. in_channels (int): Number of channels in the input feature
  25. map.
  26. inner_channels (int): Number of channels in inner features.
  27. start_level (int): The start level of the input features
  28. used in PanopticFPN.
  29. end_level (int): The end level of the used features, the
  30. ``end_level``-th layer will not be used.
  31. conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
  32. and config conv layer.
  33. norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
  34. norm layer. Use ``GN`` by default.
  35. init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
  36. dict.
  37. loss_seg (Union[ConfigDict, dict]): the loss of the semantic head.
  38. """
  39. def __init__(self,
  40. num_things_classes: int = 80,
  41. num_stuff_classes: int = 53,
  42. in_channels: int = 256,
  43. inner_channels: int = 128,
  44. start_level: int = 0,
  45. end_level: int = 4,
  46. conv_cfg: OptConfigType = None,
  47. norm_cfg: ConfigType = dict(
  48. type='GN', num_groups=32, requires_grad=True),
  49. loss_seg: ConfigType = dict(
  50. type='CrossEntropyLoss', ignore_index=-1,
  51. loss_weight=1.0),
  52. init_cfg: OptMultiConfig = None) -> None:
  53. seg_rescale_factor = 1 / 2**(start_level + 2)
  54. super().__init__(
  55. num_classes=num_stuff_classes + 1,
  56. seg_rescale_factor=seg_rescale_factor,
  57. loss_seg=loss_seg,
  58. init_cfg=init_cfg)
  59. self.num_things_classes = num_things_classes
  60. self.num_stuff_classes = num_stuff_classes
  61. # Used feature layers are [start_level, end_level)
  62. self.start_level = start_level
  63. self.end_level = end_level
  64. self.num_stages = end_level - start_level
  65. self.inner_channels = inner_channels
  66. self.conv_upsample_layers = ModuleList()
  67. for i in range(start_level, end_level):
  68. self.conv_upsample_layers.append(
  69. ConvUpsample(
  70. in_channels,
  71. inner_channels,
  72. num_layers=i if i > 0 else 1,
  73. num_upsample=i if i > 0 else 0,
  74. conv_cfg=conv_cfg,
  75. norm_cfg=norm_cfg,
  76. ))
  77. self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)
  78. def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor:
  79. """Merge thing classes to one class.
  80. In PanopticFPN, the background labels will be reset from `0` to
  81. `self.num_stuff_classes-1`, the foreground labels will be merged to
  82. `self.num_stuff_classes`-th channel.
  83. """
  84. gt_semantic_seg = gt_semantic_seg.int()
  85. fg_mask = gt_semantic_seg < self.num_things_classes
  86. bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
  87. gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)
  88. new_gt_seg = torch.clone(gt_semantic_seg)
  89. new_gt_seg = torch.where(bg_mask,
  90. gt_semantic_seg - self.num_things_classes,
  91. new_gt_seg)
  92. new_gt_seg = torch.where(fg_mask,
  93. fg_mask.int() * self.num_stuff_classes,
  94. new_gt_seg)
  95. return new_gt_seg
  96. def loss(self, x: Union[Tensor, Tuple[Tensor]],
  97. batch_data_samples: SampleList) -> Dict[str, Tensor]:
  98. """
  99. Args:
  100. x (Union[Tensor, Tuple[Tensor]]): Feature maps.
  101. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  102. data samples. It usually includes information such
  103. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  104. Returns:
  105. Dict[str, Tensor]: The loss of semantic head.
  106. """
  107. seg_preds = self(x)['seg_preds']
  108. gt_semantic_segs = [
  109. data_sample.gt_sem_seg.sem_seg
  110. for data_sample in batch_data_samples
  111. ]
  112. gt_semantic_segs = torch.stack(gt_semantic_segs)
  113. if self.seg_rescale_factor != 1.0:
  114. gt_semantic_segs = F.interpolate(
  115. gt_semantic_segs.float(),
  116. scale_factor=self.seg_rescale_factor,
  117. mode='nearest').squeeze(1)
  118. # Things classes will be merged to one class in PanopticFPN.
  119. gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)
  120. if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
  121. seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
  122. seg_preds = seg_preds.permute((0, 2, 3, 1))
  123. loss_seg = self.loss_seg(
  124. seg_preds.reshape(-1, self.num_classes), # => [NxHxW, C]
  125. gt_semantic_segs.reshape(-1).long())
  126. return dict(loss_seg=loss_seg)
  127. def init_weights(self) -> None:
  128. """Initialize weights."""
  129. super().init_weights()
  130. nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
  131. self.conv_logits.bias.data.zero_()
  132. def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
  133. """Forward.
  134. Args:
  135. x (Tuple[Tensor]): Multi scale Feature maps.
  136. Returns:
  137. dict[str, Tensor]: semantic segmentation predictions and
  138. feature maps.
  139. """
  140. # the number of subnets must be not more than
  141. # the length of features.
  142. assert self.num_stages <= len(x)
  143. feats = []
  144. for i, layer in enumerate(self.conv_upsample_layers):
  145. f = layer(x[self.start_level + i])
  146. feats.append(f)
  147. seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
  148. seg_preds = self.conv_logits(seg_feats)
  149. out = dict(seg_preds=seg_preds, seg_feats=seg_feats)
  150. return out