# activations.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmengine.utils import digit_version
  5. from mmdet.registry import MODELS
  6. if digit_version(torch.__version__) >= digit_version('1.7.0'):
  7. from torch.nn import SiLU
  8. else:
  9. class SiLU(nn.Module):
  10. """Sigmoid Weighted Liner Unit."""
  11. def __init__(self, inplace=True):
  12. super().__init__()
  13. def forward(self, inputs) -> torch.Tensor:
  14. return inputs * torch.sigmoid(inputs)
  15. MODELS.register_module(module=SiLU, name='SiLU')