contextmanagers.py

# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import contextlib
import logging
import os
import time
from typing import List, Optional

import torch

logger = logging.getLogger(__name__)

# Note: bool() of any non-empty string is True, so any set value of the
# environment variable (including '0') enables the timing path.
DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
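
# A minimal sketch (an assumption, not from the original file): enable the
# timing path from the shell before launching a script, e.g.
#
#     DEBUG_COMPLETED_TIME=1 python your_script.py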


@contextlib.asynccontextmanager
async def completed(trace_name='',
                    name='',
                    sleep_interval=0.05,
                    streams: Optional[List[torch.cuda.Stream]] = None):
    """Async context manager that waits for work to complete on given CUDA
    streams."""
    if not torch.cuda.is_available():
        yield
        return

    stream_before_context_switch = torch.cuda.current_stream()
    if not streams:
        streams = [stream_before_context_switch]
    else:
        streams = [s if s else stream_before_context_switch for s in streams]

    end_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        start = torch.cuda.Event(enable_timing=True)
        stream_before_context_switch.record_event(start)
        cpu_start = time.monotonic()
    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
    grad_enabled_before = torch.is_grad_enabled()
    try:
        yield
    finally:
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()

        # Record an end event on every stream so completion can be polled.
        for i, stream in enumerate(streams):
            event = end_events[i]
            stream.record_event(event)

        grad_enabled_after = torch.is_grad_enabled()

        # observed change of torch.is_grad_enabled() during concurrent run of
        # async_test_bboxes code
        assert (grad_enabled_before == grad_enabled_after
                ), 'Unexpected is_grad_enabled() value change'

        are_done = [e.query() for e in end_events]
        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                     are_done, streams)

        # Poll the end events without blocking the event loop, so other
        # asyncio tasks can run while the GPU work finishes.
        with torch.cuda.stream(stream_before_context_switch):
            while not all(are_done):
                await asyncio.sleep(sleep_interval)
                are_done = [e.query() for e in end_events]
                logger.debug(
                    '%s %s completed: %s streams: %s',
                    trace_name,
                    name,
                    are_done,
                    streams,
                )

        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ''
            for i, stream in enumerate(streams):
                elapsed_time = start.elapsed_time(end_events[i])
                stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
                        stream_times_ms)
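

# Usage sketch for `completed` (illustrative; `model` and `data` are
# hypothetical placeholders, not part of this module): wrap GPU work so
# other asyncio tasks can run while the kernels queued on the current
# stream finish. On exit from the `async with`, the recorded end events
# have been polled as done, so the result is safe to consume.
#
#     async def async_forward(model, data):
#         async with completed('async_forward', 'forward'):
#             result = model(data)  # kernels enqueued on the current stream
#         return result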


@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
                     trace_name='concurrent',
                     name='stream'):
    """Run code concurrently in different streams.

    :param streamqueue: asyncio.Queue instance.

    Queue tasks define the pool of streams used for concurrent execution.
    """
    if not torch.cuda.is_available():
        yield
        return

    initial_stream = torch.cuda.current_stream()

    with torch.cuda.stream(initial_stream):
        # Check a stream out of the pool; waits until one is available.
        stream = await streamqueue.get()
        assert isinstance(stream, torch.cuda.Stream)

        try:
            with torch.cuda.stream(stream):
                logger.debug('%s %s is starting, stream: %s', trace_name,
                             name, stream)
                yield
                current = torch.cuda.current_stream()
                assert current == stream
                logger.debug('%s %s has finished, stream: %s', trace_name,
                             name, stream)
        finally:
            # Return the stream to the pool so other tasks can reuse it.
            streamqueue.task_done()
            streamqueue.put_nowait(stream)
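

# Usage sketch for `concurrent` (illustrative; the pool size and the
# `infer` helper are assumptions for the example): put a pool of CUDA
# streams on an asyncio.Queue, then let each task check one out, enqueue
# its kernels on it, and await completion before returning the result.
#
#     streamqueue: asyncio.Queue = asyncio.Queue()
#     for _ in range(3):
#         streamqueue.put_nowait(torch.cuda.Stream())
#
#     async def infer(model, data):
#         async with concurrent(streamqueue):
#             async with completed('infer', 'forward'):
#                 result = model(data)
#         return result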