# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule, Linear
from mmengine.model import ModuleList
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import MultiConfig
from .fcn_mask_head import FCNMaskHead


@MODELS.register_module()
class CoarseMaskHead(FCNMaskHead):
    """Coarse mask head used in PointRend.

    Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample
    the input feature map instead of upsample it.

    Args:
        num_convs (int): Number of conv layers in the head. Defaults to 0.
        num_fcs (int): Number of fc layers in the head. Must be positive.
            Defaults to 2.
        fc_out_channels (int): Number of output channels of fc layer.
            Defaults to 1024.
        downsample_factor (int): The factor that feature map is
            downsampled by. Must be at least 1; a value of 1 disables the
            downsampling conv. Defaults to 2.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        *arg: Extra positional arguments forwarded to
            ``FCNMaskHead.__init__``.
        **kwarg: Extra keyword arguments forwarded to
            ``FCNMaskHead.__init__``.
    """

    def __init__(self,
                 num_convs: int = 0,
                 num_fcs: int = 2,
                 fc_out_channels: int = 1024,
                 downsample_factor: int = 2,
                 init_cfg: MultiConfig = dict(
                     type='Xavier',
                     override=[
                         dict(name='fcs'),
                         dict(type='Constant', val=0.001, name='fc_logits')
                     ]),
                 *arg,
                 **kwarg) -> None:
        # The parent is built with ``init_cfg=None``; this head's own
        # ``init_cfg`` is attached right after the parent constructor returns.
        super().__init__(
            *arg,
            num_convs=num_convs,
            upsample_cfg=dict(type=None),
            init_cfg=None,
            **kwarg)
        self.init_cfg = init_cfg

        self.num_fcs = num_fcs
        assert self.num_fcs > 0, 'num_fcs must be positive'
        self.fc_out_channels = fc_out_channels
        self.downsample_factor = downsample_factor
        assert self.downsample_factor >= 1, \
            'downsample_factor must be at least 1'
        # remove conv_logit: the logits come from ``fc_logits`` instead
        delattr(self, 'conv_logits')

        # Channels leaving the (possibly empty) conv stack of the parent.
        trunk_out_channels = (
            self.conv_out_channels if self.num_convs > 0 else self.in_channels)

        if downsample_factor > 1:
            # Non-overlapping strided conv: kernel size equals the stride.
            self.downsample_conv = ConvModule(
                trunk_out_channels,
                self.conv_out_channels,
                kernel_size=downsample_factor,
                stride=downsample_factor,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
            feat_channels = self.conv_out_channels
        else:
            self.downsample_conv = None
            # Without a downsample conv the fc stack sees the conv stack's
            # output directly; with ``num_convs == 0`` that is the raw input,
            # which has ``in_channels`` rather than ``conv_out_channels``.
            feat_channels = trunk_out_channels

        self.output_size = (self.roi_feat_size[0] // downsample_factor,
                            self.roi_feat_size[1] // downsample_factor)
        self.output_area = self.output_size[0] * self.output_size[1]

        # Width of the flattened feature fed to the first fc layer.
        fc_in_channels = feat_channels * self.output_area

        self.fcs = ModuleList()
        for _ in range(num_fcs):
            self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
            fc_in_channels = self.fc_out_channels

        # One logit per class per output location.
        output_channels = self.num_classes * self.output_area
        self.fc_logits = Linear(fc_in_channels, output_channels)

    def init_weights(self) -> None:
        """Initialize weights.

        ``super(FCNMaskHead, self)`` starts the method lookup after
        ``FCNMaskHead`` in the MRO, so ``FCNMaskHead.init_weights`` itself is
        bypassed.
        """
        # NOTE(review): presumably the parent's initializer touches the
        # ``conv_logits`` layer deleted in ``__init__`` -- confirm.
        super(FCNMaskHead, self).init_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Extract mask RoI features.

        Returns:
            Tensor: Predicted foreground masks, viewed as
            ``(x.size(0), num_classes, *output_size)``.
        """
        for conv in self.convs:
            x = conv(x)

        if self.downsample_conv is not None:
            x = self.downsample_conv(x)

        # Collapse everything except the leading dimension for the fc stack.
        x = x.flatten(1)
        for fc in self.fcs:
            x = self.relu(fc(x))
        mask_preds = self.fc_logits(x).view(
            x.size(0), self.num_classes, *self.output_size)
        return mask_preds