test_model_misc.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from torch.autograd import gradcheck
  5. from mmdet.models.utils import interpolate_as, sigmoid_geometric_mean
  def test_interpolate_as():
      """Check that ``interpolate_as`` output takes the target's last two dims.

      Covers a 4D or 3D source paired with a 4D, 3D, or 2D numpy target; the
      source's leading dims are kept and only the spatial size changes.
      """
      feats = torch.rand((1, 5, 4, 4))
      ref = torch.rand((1, 1, 16, 16))
      batched_shape = torch.Size((1, 5, 16, 16))
      unbatched_shape = torch.Size((5, 16, 16))

      # 4D source against a 4D target, then against a 3D (squeezed) target.
      for tgt in (ref, ref.squeeze(0)):
          assert interpolate_as(feats, tgt).shape == batched_shape

      # A 3D source produces a 3D result.
      assert interpolate_as(feats.squeeze(0), ref).shape == unbatched_shape

      # The target may also be a plain numpy array rather than a tensor.
      ref_np = np.random.rand(16, 16)
      assert interpolate_as(feats.squeeze(0), ref_np).shape == unbatched_shape
  def test_sigmoid_geometric_mean():
      """Numerically verify the gradients of ``sigmoid_geometric_mean``."""
      shape = (20, 20)
      lhs, rhs = (
          torch.randn(*shape, dtype=torch.double, requires_grad=True)
          for _ in range(2)
      )
      # gradcheck compares analytical gradients with finite differences;
      # double-precision inputs keep those finite differences accurate.
      passed = gradcheck(
          sigmoid_geometric_mean, (lhs, rhs), eps=1e-6, atol=1e-4)
      assert passed