coarse_mask_head.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.cnn import ConvModule, Linear
  3. from mmengine.model import ModuleList
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from mmdet.utils import MultiConfig
  7. from .fcn_mask_head import FCNMaskHead
  8. @MODELS.register_module()
  9. class CoarseMaskHead(FCNMaskHead):
  10. """Coarse mask head used in PointRend.
  11. Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample
  12. the input feature map instead of upsample it.
  13. Args:
  14. num_convs (int): Number of conv layers in the head. Defaults to 0.
  15. num_fcs (int): Number of fc layers in the head. Defaults to 2.
  16. fc_out_channels (int): Number of output channels of fc layer.
  17. Defaults to 1024.
  18. downsample_factor (int): The factor that feature map is downsampled by.
  19. Defaults to 2.
  20. init_cfg (dict or list[dict], optional): Initialization config dict.
  21. """
  22. def __init__(self,
  23. num_convs: int = 0,
  24. num_fcs: int = 2,
  25. fc_out_channels: int = 1024,
  26. downsample_factor: int = 2,
  27. init_cfg: MultiConfig = dict(
  28. type='Xavier',
  29. override=[
  30. dict(name='fcs'),
  31. dict(type='Constant', val=0.001, name='fc_logits')
  32. ]),
  33. *arg,
  34. **kwarg) -> None:
  35. super().__init__(
  36. *arg,
  37. num_convs=num_convs,
  38. upsample_cfg=dict(type=None),
  39. init_cfg=None,
  40. **kwarg)
  41. self.init_cfg = init_cfg
  42. self.num_fcs = num_fcs
  43. assert self.num_fcs > 0
  44. self.fc_out_channels = fc_out_channels
  45. self.downsample_factor = downsample_factor
  46. assert self.downsample_factor >= 1
  47. # remove conv_logit
  48. delattr(self, 'conv_logits')
  49. if downsample_factor > 1:
  50. downsample_in_channels = (
  51. self.conv_out_channels
  52. if self.num_convs > 0 else self.in_channels)
  53. self.downsample_conv = ConvModule(
  54. downsample_in_channels,
  55. self.conv_out_channels,
  56. kernel_size=downsample_factor,
  57. stride=downsample_factor,
  58. padding=0,
  59. conv_cfg=self.conv_cfg,
  60. norm_cfg=self.norm_cfg)
  61. else:
  62. self.downsample_conv = None
  63. self.output_size = (self.roi_feat_size[0] // downsample_factor,
  64. self.roi_feat_size[1] // downsample_factor)
  65. self.output_area = self.output_size[0] * self.output_size[1]
  66. last_layer_dim = self.conv_out_channels * self.output_area
  67. self.fcs = ModuleList()
  68. for i in range(num_fcs):
  69. fc_in_channels = (
  70. last_layer_dim if i == 0 else self.fc_out_channels)
  71. self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
  72. last_layer_dim = self.fc_out_channels
  73. output_channels = self.num_classes * self.output_area
  74. self.fc_logits = Linear(last_layer_dim, output_channels)
  75. def init_weights(self) -> None:
  76. """Initialize weights."""
  77. super(FCNMaskHead, self).init_weights()
  78. def forward(self, x: Tensor) -> Tensor:
  79. """Forward features from the upstream network.
  80. Args:
  81. x (Tensor): Extract mask RoI features.
  82. Returns:
  83. Tensor: Predicted foreground masks.
  84. """
  85. for conv in self.convs:
  86. x = conv(x)
  87. if self.downsample_conv is not None:
  88. x = self.downsample_conv(x)
  89. x = x.flatten(1)
  90. for fc in self.fcs:
  91. x = self.relu(fc(x))
  92. mask_preds = self.fc_logits(x).view(
  93. x.size(0), self.num_classes, *self.output_size)
  94. return mask_preds