# profiling.py (~1.3 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import contextlib
  3. import sys
  4. import time
  5. import torch
  6. if sys.version_info >= (3, 7):
  7. @contextlib.contextmanager
  8. def profile_time(trace_name,
  9. name,
  10. enabled=True,
  11. stream=None,
  12. end_stream=None):
  13. """Print time spent by CPU and GPU.
  14. Useful as a temporary context manager to find sweet spots of code
  15. suitable for async implementation.
  16. """
  17. if (not enabled) or not torch.cuda.is_available():
  18. yield
  19. return
  20. stream = stream if stream else torch.cuda.current_stream()
  21. end_stream = end_stream if end_stream else stream
  22. start = torch.cuda.Event(enable_timing=True)
  23. end = torch.cuda.Event(enable_timing=True)
  24. stream.record_event(start)
  25. try:
  26. cpu_start = time.monotonic()
  27. yield
  28. finally:
  29. cpu_end = time.monotonic()
  30. end_stream.record_event(end)
  31. end.synchronize()
  32. cpu_time = (cpu_end - cpu_start) * 1000
  33. gpu_time = start.elapsed_time(end)
  34. msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
  35. msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
  36. print(msg, end_stream)