# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmengine.model import constant_init

from mmdet.models.layers import DyReLU, SELayer


def test_se_layer():
    with pytest.raises(AssertionError):
        # act_cfg sequence length must equal to 2
        SELayer(channels=32, act_cfg=(dict(type='ReLU'), ))

    with pytest.raises(AssertionError):
        # act_cfg sequence must be a tuple of dict
        SELayer(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])
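
    # SELayer applies squeeze-and-excitation style channel re-weighting
    # (global average pool -> two 1x1 convs -> per-channel gate), so the
    # output keeps the spatial shape of the input.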
    # Test SELayer forward
    layer = SELayer(channels=32)
    layer.init_weights()
    layer.train()

    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))


def test_dyrelu():
    with pytest.raises(AssertionError):
        # act_cfg sequence length must equal to 2
        DyReLU(channels=32, act_cfg=(dict(type='ReLU'), ))

    with pytest.raises(AssertionError):
        # act_cfg sequence must be a tuple of dict
        DyReLU(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])
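
    # DyReLU computes max(a1 * x + b1, a2 * x + b2) with the four
    # coefficients predicted per channel by an SE-like branch, so it also
    # preserves the input's spatial shape.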
    # Test DyReLU forward
    layer = DyReLU(channels=32)
    layer.init_weights()
    layer.train()

    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))

    # DyReLU should act as a standard (static) ReLU when the effect of
    # the SE-like module is eliminated.
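    # Zero-initializing conv2 fixes the dynamic coefficients at their
    # neutral values (a1=1, b1=a2=b2=0), so max(a1*x + b1, a2*x + b2)
    # reduces to relu(x).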
    layer = DyReLU(channels=32)
    constant_init(layer.conv2.conv, 0)
    layer.train()

    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    relu_out = F.relu(x)
    assert torch.equal(x_out, relu_out)