# test_memory.py
import numpy as np
import pytest
import torch

from mmdet.utils import AvoidOOM
from mmdet.utils.memory import cast_tensor_type
  6. def test_avoidoom():
  7. tensor = torch.from_numpy(np.random.random((20, 20)))
  8. if torch.cuda.is_available():
  9. tensor = tensor.cuda()
  10. # get default result
  11. default_result = torch.mm(tensor, tensor.transpose(1, 0))
  12. # when not occurred OOM error
  13. AvoidCudaOOM = AvoidOOM()
  14. result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
  15. tensor.transpose(
  16. 1, 0))
  17. assert default_result.device == result.device and \
  18. default_result.dtype == result.dtype and \
  19. torch.equal(default_result, result)
  20. # calculate with fp16 and convert back to source type
  21. AvoidCudaOOM = AvoidOOM(test=True)
  22. result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
  23. tensor.transpose(
  24. 1, 0))
  25. assert default_result.device == result.device and \
  26. default_result.dtype == result.dtype and \
  27. torch.allclose(default_result, result, 1e-3)
  28. # calculate on cpu and convert back to source device
  29. AvoidCudaOOM = AvoidOOM(test=True)
  30. result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
  31. tensor.transpose(
  32. 1, 0))
  33. assert result.dtype == default_result.dtype and \
  34. result.device == default_result.device and \
  35. torch.allclose(default_result, result)
  36. # do not calculate on cpu and the outputs will be same as input
  37. AvoidCudaOOM = AvoidOOM(test=True, to_cpu=False)
  38. result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
  39. tensor.transpose(
  40. 1, 0))
  41. assert result.dtype == default_result.dtype and \
  42. result.device == default_result.device
  43. else:
  44. default_result = torch.mm(tensor, tensor.transpose(1, 0))
  45. AvoidCudaOOM = AvoidOOM()
  46. result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
  47. tensor.transpose(
  48. 1, 0))
  49. assert default_result.device == result.device and \
  50. default_result.dtype == result.dtype and \
  51. torch.equal(default_result, result)
  52. def test_cast_tensor_type():
  53. inputs = torch.rand(10)
  54. if torch.cuda.is_available():
  55. inputs = inputs.cuda()
  56. with pytest.raises(AssertionError):
  57. cast_tensor_type(inputs, src_type=None, dst_type=None)
  58. # input is a float
  59. out = cast_tensor_type(10., dst_type=torch.half)
  60. assert out == 10. and isinstance(out, float)
  61. # convert Tensor to fp16 and re-convert to fp32
  62. fp16_out = cast_tensor_type(inputs, dst_type=torch.half)
  63. assert fp16_out.dtype == torch.half
  64. fp32_out = cast_tensor_type(fp16_out, dst_type=torch.float32)
  65. assert fp32_out.dtype == torch.float32
  66. # input is a list
  67. list_input = [inputs, inputs]
  68. list_outs = cast_tensor_type(list_input, dst_type=torch.half)
  69. assert len(list_outs) == len(list_input) and \
  70. isinstance(list_outs, list)
  71. for out in list_outs:
  72. assert out.dtype == torch.half
  73. # input is a dict
  74. dict_input = {'test1': inputs, 'test2': inputs}
  75. dict_outs = cast_tensor_type(dict_input, dst_type=torch.half)
  76. assert len(dict_outs) == len(dict_input) and \
  77. isinstance(dict_outs, dict)
  78. # convert the input tensor to CPU and re-convert to GPU
  79. if torch.cuda.is_available():
  80. cpu_device = torch.empty(0).device
  81. gpu_device = inputs.device
  82. cpu_out = cast_tensor_type(inputs, dst_type=cpu_device)
  83. assert cpu_out.device == cpu_device
  84. gpu_out = cast_tensor_type(inputs, dst_type=gpu_device)
  85. assert gpu_out.device == gpu_device