memory.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import abc
from contextlib import contextmanager
from functools import wraps

import torch

from mmengine.logging import MMLogger

def cast_tensor_type(inputs, src_type=None, dst_type=None):
    """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype | torch.device): Source type.
        dst_type (torch.dtype | torch.device): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
    """
    assert dst_type is not None
    if isinstance(inputs, torch.Tensor):
        if isinstance(dst_type, torch.device):
            # convert Tensor to dst_device
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'device') and \
                    (inputs.device == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
        else:
            # convert Tensor to dst_dtype
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'dtype') and \
                    (inputs.dtype == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
    # we need to ensure that the type of inputs to be casted are the same
    # as the argument `src_type`.
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
            for item in inputs)
    # TODO: Currently not supported
    # elif isinstance(inputs, InstanceData):
    #     for key, value in inputs.items():
    #         inputs[key] = cast_tensor_type(
    #             value, src_type=src_type, dst_type=dst_type)
    #     return inputs
    else:
        return inputs
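

# Illustrative usage sketch of ``cast_tensor_type`` (not part of the original
# module). It walks dicts and iterables, converting only tensors whose dtype
# or device matches ``src_type`` (or all tensors when ``src_type`` is None),
# and returns non-tensor leaves unchanged. The variable names are examples
# only.
#
#   >>> data = {'feat': torch.ones(2, 2), 'scores': [torch.zeros(3)]}
#   >>> fp16_data = cast_tensor_type(data, dst_type=torch.half)
#   >>> fp16_data['feat'].dtype
#   torch.float16
#   >>> cpu_data = cast_tensor_type(data, dst_type=torch.device('cpu'))
#   >>> cpu_data['scores'][0].device
#   device(type='cpu')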


@contextmanager
def _ignore_torch_cuda_oom():
    """A context which ignores CUDA OOM exception from pytorch.

    Code is modified from
    <https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py>  # noqa: E501
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the string may change?
        if 'CUDA out of memory. ' in str(e):
            pass
        else:
            raise
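

# Minimal sketch (not part of the original module) of how this context
# manager is typically used: run an operation, and fall back to a cheaper
# path if CUDA OOM was raised and swallowed. ``heavy_op`` and
# ``cheaper_fallback`` are hypothetical callables used only for illustration.
#
#   >>> result = None
#   >>> with _ignore_torch_cuda_oom():
#   >>>     result = heavy_op()
#   >>> if result is None:  # OOM was swallowed, try something else
#   >>>     result = cheaper_fallback()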


class AvoidOOM:
    """Try to convert inputs to FP16 and CPU if a PyTorch CUDA Out of Memory
    error is encountered. It will do the following steps:

        1. First retry after calling `torch.cuda.empty_cache()`.
        2. If that still fails, it will then retry by converting inputs
           to FP16.
        3. If that still fails, it will try to convert inputs to CPU.
           In this case, it expects the function to dispatch to a
           CPU implementation.

    Args:
        to_cpu (bool): Whether to convert outputs to CPU if an OOM error
            is encountered. This will slow down the code significantly.
            Defaults to True.
        test (bool): Skip the `_ignore_torch_cuda_oom` retries so that
            lightweight data can be used in unit tests. Only used in
            unit tests. Defaults to False.

    Examples:
        >>> from mmdet.utils.memory import AvoidOOM
        >>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
        >>>     some_torch_function)(input1, input2)
        >>> # To use as a decorator
        >>> # from mmdet.utils import AvoidCUDAOOM
        >>> @AvoidCUDAOOM.retry_if_cuda_oom
        >>> def function(*args, **kwargs):
        >>>     return None

    Note:
        1. The output may be on CPU even if inputs are on GPU. Processing
           on CPU will slow down the code significantly.
        2. When converting inputs to CPU, it will only look at each argument
           and check if it has `.device` and `.to` for conversion. Nested
           structures of tensors are not supported.
        3. Since the function might be called more than once, it has to be
           stateless.
    """

    def __init__(self, to_cpu=True, test=False):
        self.to_cpu = to_cpu
        self.test = test

    def retry_if_cuda_oom(self, func):
        """Makes a function retry itself after encountering pytorch's CUDA OOM
        error.

        The implementation logic is adapted from
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py

        Args:
            func: a stateless callable that takes tensor-like objects
                as arguments.

        Returns:
            func: a callable which retries `func` if OOM is encountered.
        """  # noqa: W605

        @wraps(func)
        def wrapped(*args, **kwargs):
            # raw function
            if not self.test:
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

                # Clear cache and retry
                torch.cuda.empty_cache()
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

            # get the type and device of first tensor
            dtype, device = None, None
            values = args + tuple(kwargs.values())
            for value in values:
                if isinstance(value, torch.Tensor):
                    dtype = value.dtype
                    device = value.device
                    break
            if dtype is None or device is None:
                raise ValueError('There is no tensor in the inputs, '
                                 'cannot get dtype and device.')

            # Convert to FP16
            fp16_args = cast_tensor_type(args, dst_type=torch.half)
            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            logger = MMLogger.get_current_instance()
            logger.warning(f'Attempting to copy inputs of {str(func)} '
                           'to FP16 due to CUDA OOM')

            # get the input tensor dtype; the output will be cast back to
            # the dtype of the first tensor argument.
            with _ignore_torch_cuda_oom():
                output = func(*fp16_args, **fp16_kwargs)
                output = cast_tensor_type(
                    output, src_type=torch.half, dst_type=dtype)
                if not self.test:
                    return output
            logger.warning('Using FP16 still meets CUDA OOM')

            # Try on CPU. This will slow down the code significantly,
            # therefore print a notice.
            if self.to_cpu:
                logger.warning(f'Attempting to copy inputs of {str(func)} '
                               'to CPU due to CUDA OOM')
                cpu_device = torch.empty(0).device
                cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # try to convert outputs back to GPU
                with _ignore_torch_cuda_oom():
                    logger.warning(
                        f'Convert outputs to GPU (device={device})')
                    output = func(*cpu_args, **cpu_kwargs)
                    output = cast_tensor_type(
                        output, src_type=cpu_device, dst_type=device)
                    return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output needs to interact with '
                              'GPU data in subsequent operations')
                logger.warning('Cannot convert output to GPU due to '
                               'CUDA OOM, the output is on CPU now.')

                return func(*cpu_args, **cpu_kwargs)
            else:
                # may still get CUDA OOM error
                return func(*args, **kwargs)

        return wrapped


# To use AvoidOOM as a decorator
AvoidCUDAOOM = AvoidOOM()
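

# Illustrative usage sketch (not part of the original module). ``heavy_matmul``
# is a hypothetical stateless function; wrapping it with
# ``AvoidCUDAOOM.retry_if_cuda_oom`` retries it on CUDA OOM, first after
# ``torch.cuda.empty_cache()``, then with FP16 inputs, and finally on CPU
# when ``to_cpu=True``.
#
#   >>> def heavy_matmul(x, y):
#   >>>     return x @ y
#   >>> safe_matmul = AvoidCUDAOOM.retry_if_cuda_oom(heavy_matmul)
#   >>> out = safe_matmul(torch.randn(8, 8, device='cuda'),
#   >>>                   torch.randn(8, 8, device='cuda'))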