test_conv_upsample.py 629 B

123456789101112131415161718192021222324
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.layers import ConvUpsample
  5. @pytest.mark.parametrize('num_layers', [0, 1, 2])
  6. def test_conv_upsample(num_layers):
  7. num_upsample = num_layers if num_layers > 0 else 0
  8. num_layers = num_layers if num_layers > 0 else 1
  9. layer = ConvUpsample(
  10. 10,
  11. 5,
  12. num_layers=num_layers,
  13. num_upsample=num_upsample,
  14. conv_cfg=None,
  15. norm_cfg=None)
  16. size = 5
  17. x = torch.randn((1, 10, size, size))
  18. size = size * pow(2, num_upsample)
  19. x = layer(x)
  20. assert x.shape[-2:] == (size, size)