htc_mask_head.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Union
  3. from mmcv.cnn import ConvModule
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from .fcn_mask_head import FCNMaskHead
  7. @MODELS.register_module()
  8. class HTCMaskHead(FCNMaskHead):
  9. """Mask head for HTC.
  10. Args:
  11. with_conv_res (bool): Whether add conv layer for ``res_feat``.
  12. Defaults to True.
  13. """
  14. def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None:
  15. super().__init__(*args, **kwargs)
  16. self.with_conv_res = with_conv_res
  17. if self.with_conv_res:
  18. self.conv_res = ConvModule(
  19. self.conv_out_channels,
  20. self.conv_out_channels,
  21. 1,
  22. conv_cfg=self.conv_cfg,
  23. norm_cfg=self.norm_cfg)
  24. def forward(self,
  25. x: Tensor,
  26. res_feat: Optional[Tensor] = None,
  27. return_logits: bool = True,
  28. return_feat: bool = True) -> Union[Tensor, List[Tensor]]:
  29. """
  30. Args:
  31. x (Tensor): Feature map.
  32. res_feat (Tensor, optional): Feature for residual connection.
  33. Defaults to None.
  34. return_logits (bool): Whether return mask logits. Defaults to True.
  35. return_feat (bool): Whether return feature map. Defaults to True.
  36. Returns:
  37. Union[Tensor, List[Tensor]]: The return result is one of three
  38. results: res_feat, logits, or [logits, res_feat].
  39. """
  40. assert not (not return_logits and not return_feat)
  41. if res_feat is not None:
  42. assert self.with_conv_res
  43. res_feat = self.conv_res(res_feat)
  44. x = x + res_feat
  45. for conv in self.convs:
  46. x = conv(x)
  47. res_feat = x
  48. outs = []
  49. if return_logits:
  50. x = self.upsample(x)
  51. if self.upsample_method == 'deconv':
  52. x = self.relu(x)
  53. mask_preds = self.conv_logits(x)
  54. outs.append(mask_preds)
  55. if return_feat:
  56. outs.append(res_feat)
  57. return outs if len(outs) > 1 else outs[0]