memory_profiler_hook.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Sequence
  3. from mmengine.hooks import Hook
  4. from mmengine.runner import Runner
  5. from mmdet.registry import HOOKS
  6. from mmdet.structures import DetDataSample
  7. @HOOKS.register_module()
  8. class MemoryProfilerHook(Hook):
  9. """Memory profiler hook recording memory information including virtual
  10. memory, swap memory, and the memory of the current process.
  11. Args:
  12. interval (int): Checking interval (every k iterations).
  13. Default: 50.
  14. """
  15. def __init__(self, interval: int = 50) -> None:
  16. try:
  17. from psutil import swap_memory, virtual_memory
  18. self._swap_memory = swap_memory
  19. self._virtual_memory = virtual_memory
  20. except ImportError:
  21. raise ImportError('psutil is not installed, please install it by: '
  22. 'pip install psutil')
  23. try:
  24. from memory_profiler import memory_usage
  25. self._memory_usage = memory_usage
  26. except ImportError:
  27. raise ImportError(
  28. 'memory_profiler is not installed, please install it by: '
  29. 'pip install memory_profiler')
  30. self.interval = interval
  31. def _record_memory_information(self, runner: Runner) -> None:
  32. """Regularly record memory information.
  33. Args:
  34. runner (:obj:`Runner`): The runner of the training or evaluation
  35. process.
  36. """
  37. # in Byte
  38. virtual_memory = self._virtual_memory()
  39. swap_memory = self._swap_memory()
  40. # in MB
  41. process_memory = self._memory_usage()[0]
  42. factor = 1024 * 1024
  43. runner.logger.info(
  44. 'Memory information '
  45. 'available_memory: '
  46. f'{round(virtual_memory.available / factor)} MB, '
  47. 'used_memory: '
  48. f'{round(virtual_memory.used / factor)} MB, '
  49. f'memory_utilization: {virtual_memory.percent} %, '
  50. 'available_swap_memory: '
  51. f'{round((swap_memory.total - swap_memory.used) / factor)}'
  52. ' MB, '
  53. f'used_swap_memory: {round(swap_memory.used / factor)} MB, '
  54. f'swap_memory_utilization: {swap_memory.percent} %, '
  55. 'current_process_memory: '
  56. f'{round(process_memory)} MB')
  57. def after_train_iter(self,
  58. runner: Runner,
  59. batch_idx: int,
  60. data_batch: Optional[dict] = None,
  61. outputs: Optional[dict] = None) -> None:
  62. """Regularly record memory information.
  63. Args:
  64. runner (:obj:`Runner`): The runner of the training process.
  65. batch_idx (int): The index of the current batch in the train loop.
  66. data_batch (dict, optional): Data from dataloader.
  67. Defaults to None.
  68. outputs (dict, optional): Outputs from model. Defaults to None.
  69. """
  70. if self.every_n_inner_iters(batch_idx, self.interval):
  71. self._record_memory_information(runner)
  72. def after_val_iter(
  73. self,
  74. runner: Runner,
  75. batch_idx: int,
  76. data_batch: Optional[dict] = None,
  77. outputs: Optional[Sequence[DetDataSample]] = None) -> None:
  78. """Regularly record memory information.
  79. Args:
  80. runner (:obj:`Runner`): The runner of the validation process.
  81. batch_idx (int): The index of the current batch in the val loop.
  82. data_batch (dict, optional): Data from dataloader.
  83. Defaults to None.
  84. outputs (Sequence[:obj:`DetDataSample`], optional):
  85. Outputs from model. Defaults to None.
  86. """
  87. if self.every_n_inner_iters(batch_idx, self.interval):
  88. self._record_memory_information(runner)
  89. def after_test_iter(
  90. self,
  91. runner: Runner,
  92. batch_idx: int,
  93. data_batch: Optional[dict] = None,
  94. outputs: Optional[Sequence[DetDataSample]] = None) -> None:
  95. """Regularly record memory information.
  96. Args:
  97. runner (:obj:`Runner`): The runner of the testing process.
  98. batch_idx (int): The index of the current batch in the test loop.
  99. data_batch (dict, optional): Data from dataloader.
  100. Defaults to None.
  101. outputs (Sequence[:obj:`DetDataSample`], optional):
  102. Outputs from model. Defaults to None.
  103. """
  104. if self.every_n_inner_iters(batch_idx, self.interval):
  105. self._record_memory_information(runner)