setup_env.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import datetime
  3. import logging
  4. import os
  5. import platform
  6. import warnings
  7. import cv2
  8. import torch.multiprocessing as mp
  9. from mmengine import DefaultScope
  10. from mmengine.logging import print_log
  11. from mmengine.utils import digit_version
  12. def setup_cache_size_limit_of_dynamo():
  13. """Setup cache size limit of dynamo.
  14. Note: Due to the dynamic shape of the loss calculation and
  15. post-processing parts in the object detection algorithm, these
  16. functions must be compiled every time they are run.
  17. Setting a large value for torch._dynamo.config.cache_size_limit
  18. may result in repeated compilation, which can slow down training
  19. and testing speed. Therefore, we need to set the default value of
  20. cache_size_limit smaller. An empirical value is 4.
  21. """
  22. import torch
  23. if digit_version(torch.__version__) >= digit_version('2.0.0'):
  24. if 'DYNAMO_CACHE_SIZE_LIMIT' in os.environ:
  25. import torch._dynamo
  26. cache_size_limit = int(os.environ['DYNAMO_CACHE_SIZE_LIMIT'])
  27. torch._dynamo.config.cache_size_limit = cache_size_limit
  28. print_log(
  29. f'torch._dynamo.config.cache_size_limit is force '
  30. f'set to {cache_size_limit}.',
  31. logger='current',
  32. level=logging.WARNING)
  33. def setup_multi_processes(cfg):
  34. """Setup multi-processing environment variables."""
  35. # set multi-process start method as `fork` to speed up the training
  36. if platform.system() != 'Windows':
  37. mp_start_method = cfg.get('mp_start_method', 'fork')
  38. current_method = mp.get_start_method(allow_none=True)
  39. if current_method is not None and current_method != mp_start_method:
  40. warnings.warn(
  41. f'Multi-processing start method `{mp_start_method}` is '
  42. f'different from the previous setting `{current_method}`.'
  43. f'It will be force set to `{mp_start_method}`. You can change '
  44. f'this behavior by changing `mp_start_method` in your config.')
  45. mp.set_start_method(mp_start_method, force=True)
  46. # disable opencv multithreading to avoid system being overloaded
  47. opencv_num_threads = cfg.get('opencv_num_threads', 0)
  48. cv2.setNumThreads(opencv_num_threads)
  49. # setup OMP threads
  50. # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
  51. workers_per_gpu = cfg.data.get('workers_per_gpu', 1)
  52. if 'train_dataloader' in cfg.data:
  53. workers_per_gpu = \
  54. max(cfg.data.train_dataloader.get('workers_per_gpu', 1),
  55. workers_per_gpu)
  56. if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
  57. omp_num_threads = 1
  58. warnings.warn(
  59. f'Setting OMP_NUM_THREADS environment variable for each process '
  60. f'to be {omp_num_threads} in default, to avoid your system being '
  61. f'overloaded, please further tune the variable for optimal '
  62. f'performance in your application as needed.')
  63. os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
  64. # setup MKL threads
  65. if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
  66. mkl_num_threads = 1
  67. warnings.warn(
  68. f'Setting MKL_NUM_THREADS environment variable for each process '
  69. f'to be {mkl_num_threads} in default, to avoid your system being '
  70. f'overloaded, please further tune the variable for optimal '
  71. f'performance in your application as needed.')
  72. os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
  73. def register_all_modules(init_default_scope: bool = True) -> None:
  74. """Register all modules in mmdet into the registries.
  75. Args:
  76. init_default_scope (bool): Whether initialize the mmdet default scope.
  77. When `init_default_scope=True`, the global default scope will be
  78. set to `mmdet`, and all registries will build modules from mmdet's
  79. registry node. To understand more about the registry, please refer
  80. to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
  81. Defaults to True.
  82. """ # noqa
  83. import mmdet.datasets # noqa: F401,F403
  84. import mmdet.engine # noqa: F401,F403
  85. import mmdet.evaluation # noqa: F401,F403
  86. import mmdet.models # noqa: F401,F403
  87. import mmdet.visualization # noqa: F401,F403
  88. if init_default_scope:
  89. never_created = DefaultScope.get_current_instance() is None \
  90. or not DefaultScope.check_instance_created('mmdet')
  91. if never_created:
  92. DefaultScope.get_instance('mmdet', scope_name='mmdet')
  93. return
  94. current_scope = DefaultScope.get_current_instance()
  95. if current_scope.scope_name != 'mmdet':
  96. warnings.warn('The current default scope '
  97. f'"{current_scope.scope_name}" is not "mmdet", '
  98. '`register_all_modules` will force the current'
  99. 'default scope to be "mmdet". If this is not '
  100. 'expected, please set `init_default_scope=False`.')
  101. # avoid name conflict
  102. new_instance_name = f'mmdet-{datetime.datetime.now()}'
  103. DefaultScope.get_instance(new_instance_name, scope_name='mmdet')