1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Union
- from mmcv.cnn import ConvModule
- from torch import Tensor
- from mmdet.registry import MODELS
- from .fcn_mask_head import FCNMaskHead
@MODELS.register_module()
class HTCMaskHead(FCNMaskHead):
    """Mask head for HTC.

    Extends :class:`FCNMaskHead` with an optional residual path. An extra
    feature map (``res_feat``) is passed through a 1x1 conv and added
    element-wise to the input before the conv stack. :meth:`forward` can
    also return the conv-stack output, so the same kind of feature can be
    fed back in as ``res_feat`` on a later call.

    Args:
        with_conv_res (bool): Whether to add the 1x1 conv layer applied to
            ``res_feat``. Must be True if ``res_feat`` is ever passed to
            :meth:`forward`. Defaults to True.
    """

    def __init__(self, with_conv_res: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_conv_res = with_conv_res
        if self.with_conv_res:
            # Channel-preserving 1x1 conv. Its output is summed element-wise
            # with the input feature map in forward(). This assumes that input
            # also has ``conv_out_channels`` channels (TODO confirm per config).
            self.conv_res = ConvModule(
                self.conv_out_channels,
                self.conv_out_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def forward(self,
                x: Tensor,
                res_feat: Optional[Tensor] = None,
                return_logits: bool = True,
                return_feat: bool = True) -> Union[Tensor, List[Tensor]]:
        """Run the mask head, optionally fusing a residual feature.

        Args:
            x (Tensor): Feature map.
            res_feat (Tensor, optional): Feature for the residual connection.
                It is transformed by ``self.conv_res`` and added to ``x``
                before the conv stack. Requires ``with_conv_res=True``.
                Defaults to None.
            return_logits (bool): Whether to return mask logits.
                Defaults to True.
            return_feat (bool): Whether to return the feature map produced by
                the conv stack (taken before upsampling). Defaults to True.

        Returns:
            Union[Tensor, List[Tensor]]: ``[mask_preds, res_feat]`` when both
            flags are True. Otherwise, the single requested tensor
            (``mask_preds`` or ``res_feat``).

        Raises:
            AssertionError: If both ``return_logits`` and ``return_feat``
                are False, or if ``res_feat`` is given while the head was
                built with ``with_conv_res=False``.
        """
        assert return_logits or return_feat, \
            'At least one of return_logits and return_feat must be True.'

        if res_feat is not None:
            # conv_res only exists when the head was built with it.
            assert self.with_conv_res, \
                'res_feat was given but with_conv_res is False.'
            x = x + self.conv_res(res_feat)

        for conv in self.convs:
            x = conv(x)
        # Conv-stack output, taken before upsampling. This is what is
        # returned as the residual feature when return_feat is True.
        res_feat = x

        outs = []
        if return_logits:
            # FCNMaskHead sets ``self.upsample`` to None when no upsampling
            # is configured. Guard the call, as the parent's forward does,
            # instead of calling None.
            if self.upsample is not None:
                x = self.upsample(x)
                if self.upsample_method == 'deconv':
                    x = self.relu(x)
            mask_preds = self.conv_logits(x)
            outs.append(mask_preds)
        if return_feat:
            outs.append(res_feat)
        return outs if len(outs) > 1 else outs[0]
|