feature_relay_head.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch.nn as nn
  4. from mmengine.model import BaseModule
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from mmdet.utils import MultiConfig
  8. @MODELS.register_module()
  9. class FeatureRelayHead(BaseModule):
  10. """Feature Relay Head used in `SCNet <https://arxiv.org/abs/2012.10150>`_.
  11. Args:
  12. in_channels (int): number of input channels. Defaults to 256.
  13. conv_out_channels (int): number of output channels before
  14. classification layer. Defaults to 256.
  15. roi_feat_size (int): roi feat size at box head. Default: 7.
  16. scale_factor (int): scale factor to match roi feat size
  17. at mask head. Defaults to 2.
  18. init_cfg (:obj:`ConfigDict` or dict or list[dict] or
  19. list[:obj:`ConfigDict`]): Initialization config dict. Defaults to
  20. dict(type='Kaiming', layer='Linear').
  21. """
  22. def __init__(
  23. self,
  24. in_channels: int = 1024,
  25. out_conv_channels: int = 256,
  26. roi_feat_size: int = 7,
  27. scale_factor: int = 2,
  28. init_cfg: MultiConfig = dict(type='Kaiming', layer='Linear')
  29. ) -> None:
  30. super().__init__(init_cfg=init_cfg)
  31. assert isinstance(roi_feat_size, int)
  32. self.in_channels = in_channels
  33. self.out_conv_channels = out_conv_channels
  34. self.roi_feat_size = roi_feat_size
  35. self.out_channels = (roi_feat_size**2) * out_conv_channels
  36. self.scale_factor = scale_factor
  37. self.fp16_enabled = False
  38. self.fc = nn.Linear(self.in_channels, self.out_channels)
  39. self.upsample = nn.Upsample(
  40. scale_factor=scale_factor, mode='bilinear', align_corners=True)
  41. def forward(self, x: Tensor) -> Optional[Tensor]:
  42. """Forward function.
  43. Args:
  44. x (Tensor): Input feature.
  45. Returns:
  46. Optional[Tensor]: Output feature. When the first dim of input is
  47. 0, None is returned.
  48. """
  49. N, _ = x.shape
  50. if N > 0:
  51. out_C = self.out_conv_channels
  52. out_HW = self.roi_feat_size
  53. x = self.fc(x)
  54. x = x.reshape(N, out_C, out_HW, out_HW)
  55. x = self.upsample(x)
  56. return x
  57. return None