123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmcv.cnn import ConvModule, Linear
- from mmengine.model import ModuleList
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.utils import MultiConfig
- from .fcn_mask_head import FCNMaskHead
@MODELS.register_module()
class CoarseMaskHead(FCNMaskHead):
    """Coarse mask head used in PointRend.

    Where :class:`FCNMaskHead` upsamples RoI features before predicting
    masks, this head optionally shrinks them with a strided convolution,
    flattens the result and predicts a low-resolution mask per class with a
    stack of fully-connected layers.

    Args:
        num_convs (int): Number of conv layers in the head. Defaults to 0.
        num_fcs (int): Number of fc layers. Must be positive.
            Defaults to 2.
        fc_out_channels (int): Output width of every fc layer.
            Defaults to 1024.
        downsample_factor (int): Kernel size and stride of the downsampling
            conv. Must be >= 1; a value of 1 skips downsampling.
            Defaults to 2.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to a ``Xavier`` config with an override entry for
            ``fcs`` and a ``Constant`` (val=0.001) init for ``fc_logits``.
        *arg: Extra positional arguments forwarded to :class:`FCNMaskHead`.
        **kwarg: Extra keyword arguments forwarded to :class:`FCNMaskHead`.
            ``upsample_cfg`` and ``init_cfg`` are supplied by this class, so
            passing them here raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self,
                 num_convs: int = 0,
                 num_fcs: int = 2,
                 fc_out_channels: int = 1024,
                 downsample_factor: int = 2,
                 init_cfg: MultiConfig = dict(
                     type='Xavier',
                     override=[
                         dict(name='fcs'),
                         dict(type='Constant', val=0.001, name='fc_logits')
                     ]),
                 *arg,
                 **kwarg) -> None:
        # Build the parent with upsampling disabled and without an init
        # config of its own; this head's init_cfg is attached right after.
        super().__init__(
            *arg,
            num_convs=num_convs,
            upsample_cfg=dict(type=None),
            init_cfg=None,
            **kwarg)
        # NOTE: the default init_cfg dict is evaluated once and stored by
        # reference, so it is shared by every instance built without an
        # explicit init_cfg -- do not mutate self.init_cfg in place.
        self.init_cfg = init_cfg

        self.num_fcs = num_fcs
        assert self.num_fcs > 0
        self.fc_out_channels = fc_out_channels
        self.downsample_factor = downsample_factor
        assert self.downsample_factor >= 1

        # The parent's conv_logits is not used; per-class logits come from
        # the fc_logits layer built below.
        delattr(self, 'conv_logits')

        if downsample_factor > 1:
            # The downsampling conv consumes the output of the head convs,
            # or the raw RoI features when the head has no convs.
            downsample_in_channels = (
                self.conv_out_channels
                if self.num_convs > 0 else self.in_channels)
            self.downsample_conv = ConvModule(
                downsample_in_channels,
                self.conv_out_channels,
                kernel_size=downsample_factor,
                stride=downsample_factor,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        else:
            self.downsample_conv = None

        # Spatial size of the predicted coarse masks.
        out_h = self.roi_feat_size[0] // downsample_factor
        out_w = self.roi_feat_size[1] // downsample_factor
        self.output_size = (out_h, out_w)
        self.output_area = out_h * out_w

        # Channel count of the map that forward() flattens. Both the head
        # convs and the downsampling conv emit conv_out_channels; with
        # neither present the raw RoI features (in_channels) are flattened
        # directly. Previously conv_out_channels was assumed unconditionally,
        # which broke forward() when num_convs == 0, downsample_factor == 1
        # and in_channels != conv_out_channels.
        if self.num_convs > 0 or self.downsample_conv is not None:
            flat_channels = self.conv_out_channels
        else:
            flat_channels = self.in_channels

        fc_layers = []
        in_dim = flat_channels * self.output_area
        for _ in range(num_fcs):
            fc_layers.append(Linear(in_dim, self.fc_out_channels))
            in_dim = self.fc_out_channels
        self.fcs = ModuleList(fc_layers)

        # One logit per class per output location; forward() reshapes them
        # to (num_classes, *output_size).
        self.fc_logits = Linear(in_dim, self.num_classes * self.output_area)

    def init_weights(self) -> None:
        """Initialize weights.

        ``super(FCNMaskHead, self)`` deliberately bypasses
        ``FCNMaskHead.init_weights`` and calls the next implementation in the
        MRO instead -- presumably because the parent's version handles layers
        this head removes or never builds (``conv_logits`` is deleted and
        upsampling is disabled in ``__init__``). TODO confirm.
        """
        super(FCNMaskHead, self).init_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Predict coarse per-class mask logits from RoI features.

        Args:
            x (Tensor): Mask RoI features; presumably shaped
                (num_rois, in_channels, H, W) with (H, W) equal to
                ``roi_feat_size`` -- TODO confirm against the RoI extractor.

        Returns:
            Tensor: Mask predictions of shape
            (num_rois, num_classes, *output_size).
        """
        for conv in self.convs:
            x = conv(x)
        if self.downsample_conv is not None:
            x = self.downsample_conv(x)

        # Flatten everything after the first (per-RoI) dimension.
        x = x.flatten(1)
        for fc in self.fcs:
            # self.relu is inherited from FCNMaskHead.
            x = self.relu(fc(x))

        num_rois = x.size(0)
        logits = self.fc_logits(x)
        return logits.view(num_rois, self.num_classes, *self.output_size)
|