se_layer.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.utils import digit_version, is_tuple_of
from torch import Tensor

from mmdet.utils import MultiConfig, OptConfigType, OptMultiConfig


class SELayer(BaseModule):
    """Squeeze-and-Excitation Module.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will
            be ``int(channels/ratio)``. Defaults to 16.
        conv_cfg (None or dict): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configured
            by this dict. If act_cfg is a sequence of dicts, the first
            activation layer will be configured by the first dict and the
            second activation layer will be configured by the second dict.
            Defaults to (dict(type='ReLU'), dict(type='Sigmoid')).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 channels: int,
                 ratio: int = 16,
                 conv_cfg: OptConfigType = None,
                 act_cfg: MultiConfig = (dict(type='ReLU'),
                                         dict(type='Sigmoid')),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert is_tuple_of(act_cfg, dict)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=int(channels / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=int(channels / ratio),
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for SELayer."""
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        return x * out
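

# Not part of the upstream file: a minimal usage sketch for SELayer. The
# helper name and the 64-channel / 32x32 sizes are arbitrary example values.
# SELayer squeezes the spatial dims with global average pooling, produces
# per-channel gates in (0, 1) via two 1x1 convs, and rescales the input, so
# the output shape matches the input shape.
def _selayer_example() -> None:
    se_layer = SELayer(channels=64, ratio=16)
    feat = torch.rand(2, 64, 32, 32)  # (N, C, H, W)
    out = se_layer(feat)  # channel-wise gated copy of `feat`
    assert out.shape == feat.shape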


class DyReLU(BaseModule):
    """Dynamic ReLU (DyReLU) module.

    See `Dynamic ReLU <https://arxiv.org/abs/2003.10027>`_ for details.
    Current implementation is specialized for task-aware attention in DyHead.
    HSigmoid arguments in default act_cfg follow DyHead official code.
    https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py

    Args:
        channels (int): The input (and output) channels of DyReLU module.
        ratio (int): Squeeze ratio in Squeeze-and-Excitation-like module,
            the intermediate channel will be ``int(channels/ratio)``.
            Defaults to 4.
        conv_cfg (None or dict): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configured
            by this dict. If act_cfg is a sequence of dicts, the first
            activation layer will be configured by the first dict and the
            second activation layer will be configured by the second dict.
            Defaults to (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
            divisor=6.0)).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 channels: int,
                 ratio: int = 4,
                 conv_cfg: OptConfigType = None,
                 act_cfg: MultiConfig = (dict(type='ReLU'),
                                         dict(
                                             type='HSigmoid',
                                             bias=3.0,
                                             divisor=6.0)),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert is_tuple_of(act_cfg, dict)
        self.channels = channels
        self.expansion = 4  # for a1, b1, a2, b2
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=int(channels / ratio),
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=int(channels / ratio),
            out_channels=channels * self.expansion,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        coeffs = self.global_avgpool(x)
        coeffs = self.conv1(coeffs)
        coeffs = self.conv2(coeffs) - 0.5  # value range: [-0.5, 0.5]
        a1, b1, a2, b2 = torch.split(coeffs, self.channels, dim=1)
        a1 = a1 * 2.0 + 1.0  # [-1.0, 1.0] + 1.0
        a2 = a2 * 2.0  # [-1.0, 1.0]
        out = torch.max(x * a1 + b1, x * a2 + b2)
        return out
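

# Not part of the upstream file: an illustrative sketch of DyReLU. For each
# channel it predicts (a1, b1, a2, b2) from globally pooled features and
# returns max(a1 * x + b1, a2 * x + b2), a per-channel piecewise-linear
# activation that preserves the input shape. The helper name and tensor
# sizes are arbitrary example values.
def _dyrelu_example() -> None:
    dyrelu = DyReLU(channels=32, ratio=4)
    feat = torch.rand(2, 32, 16, 16)  # (N, C, H, W)
    out = dyrelu(feat)  # conv2 emits 4 * channels coefficients internally
    assert out.shape == feat.shape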


class ChannelAttention(BaseModule):
    """Channel attention Module.

    Args:
        channels (int): The input (and output) channels of the attention
            layer.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        if digit_version(torch.__version__) < (1, 7, 0):
            self.act = nn.Hardsigmoid()
        else:
            self.act = nn.Hardsigmoid(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for ChannelAttention."""
        with torch.cuda.amp.autocast(enabled=False):
            out = self.global_avgpool(x)
        out = self.fc(out)
        out = self.act(out)
        return x * out
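

# Not part of the upstream file: an illustrative sketch of ChannelAttention,
# a lighter SE-style gate with a single 1x1 conv followed by a hard sigmoid.
# The helper name and tensor sizes are arbitrary example values.
def _channel_attention_example() -> None:
    attn = ChannelAttention(channels=16)
    feat = torch.rand(2, 16, 8, 8)  # (N, C, H, W)
    out = attn(feat)  # same shape as the input, channels re-weighted
    assert out.shape == feat.shape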