# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmengine.model import constant_init

from mmdet.models.layers import DyReLU, SELayer


def test_se_layer():
    with pytest.raises(AssertionError):
        # the act_cfg sequence length must equal 2
        SELayer(channels=32, act_cfg=(dict(type='ReLU'), ))
    with pytest.raises(AssertionError):
        # act_cfg must be a tuple of dicts, not a list
        SELayer(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])

    # Test SELayer forward: the output shape must match the input shape
    layer = SELayer(channels=32)
    layer.init_weights()
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))
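

# For reference, SELayer implements squeeze-and-excitation channel
# attention: a global average pool squeezes each channel to a scalar,
# two 1x1 convs re-weight the channels, and the input is rescaled by the
# resulting per-channel gates. A minimal sketch of the idea (shown for
# illustration only; not mmdet's exact implementation or activations):
#
#     gate = torch.sigmoid(conv2(relu(conv1(x.mean((2, 3), keepdim=True)))))
#     out = x * gate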


def test_dyrelu():
    with pytest.raises(AssertionError):
        # the act_cfg sequence length must equal 2
        DyReLU(channels=32, act_cfg=(dict(type='ReLU'), ))
    with pytest.raises(AssertionError):
        # act_cfg must be a tuple of dicts, not a list
        DyReLU(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])

    # Test DyReLU forward: the output shape must match the input shape
    layer = DyReLU(channels=32)
    layer.init_weights()
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))

    # DyReLU should act as a standard (static) ReLU
    # when the effect of its SE-like module is eliminated
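    # (DyReLU computes a max over piecewise-linear branches a_k * x + b_k
    # whose coefficients are predicted by an SE-style conv head ending in
    # layer.conv2. Zero-initializing conv2 leaves the coefficients at
    # their static defaults, nominally (a1, b1, a2, b2) = (1, 0, 0, 0),
    # so the activation reduces to max(x, 0) = relu(x). This describes
    # the mechanism from the DyReLU paper; the exact parameterization is
    # defined in mmdet's implementation.)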
    layer = DyReLU(channels=32)
    constant_init(layer.conv2.conv, 0)
    layer.train()
    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    relu_out = F.relu(x)
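    # Bit-for-bit equality is expected here: assuming the default
    # coefficients above, the dynamic activation computes the same
    # floating-point result as F.relu, so torch.equal needs no tolerance.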
    assert torch.equal(x_out, relu_out)