scnet_semantic_head.py 988 B

12345678910111213141516171819202122232425262728
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmdet.models.layers import ResLayer, SimplifiedBasicBlock
  3. from mmdet.registry import MODELS
  4. from .fused_semantic_head import FusedSemanticHead
  5. @MODELS.register_module()
  6. class SCNetSemanticHead(FusedSemanticHead):
  7. """Mask head for `SCNet <https://arxiv.org/abs/2012.10150>`_.
  8. Args:
  9. conv_to_res (bool, optional): if True, change the conv layers to
  10. ``SimplifiedBasicBlock``.
  11. """
  12. def __init__(self, conv_to_res: bool = True, **kwargs) -> None:
  13. super().__init__(**kwargs)
  14. self.conv_to_res = conv_to_res
  15. if self.conv_to_res:
  16. num_res_blocks = self.num_convs // 2
  17. self.convs = ResLayer(
  18. SimplifiedBasicBlock,
  19. self.in_channels,
  20. self.conv_out_channels,
  21. num_res_blocks,
  22. conv_cfg=self.conv_cfg,
  23. norm_cfg=self.norm_cfg)
  24. self.num_convs = num_res_blocks