compat_config.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import warnings
  4. from mmengine.config import ConfigDict
  5. def compat_cfg(cfg):
  6. """This function would modify some filed to keep the compatibility of
  7. config.
  8. For example, it will move some args which will be deprecated to the correct
  9. fields.
  10. """
  11. cfg = copy.deepcopy(cfg)
  12. cfg = compat_imgs_per_gpu(cfg)
  13. cfg = compat_loader_args(cfg)
  14. cfg = compat_runner_args(cfg)
  15. return cfg
  16. def compat_runner_args(cfg):
  17. if 'runner' not in cfg:
  18. cfg.runner = ConfigDict({
  19. 'type': 'EpochBasedRunner',
  20. 'max_epochs': cfg.total_epochs
  21. })
  22. warnings.warn(
  23. 'config is now expected to have a `runner` section, '
  24. 'please set `runner` in your config.', UserWarning)
  25. else:
  26. if 'total_epochs' in cfg:
  27. assert cfg.total_epochs == cfg.runner.max_epochs
  28. return cfg
  29. def compat_imgs_per_gpu(cfg):
  30. cfg = copy.deepcopy(cfg)
  31. if 'imgs_per_gpu' in cfg.data:
  32. warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
  33. 'Please use "samples_per_gpu" instead')
  34. if 'samples_per_gpu' in cfg.data:
  35. warnings.warn(
  36. f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
  37. f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
  38. f'={cfg.data.imgs_per_gpu} is used in this experiments')
  39. else:
  40. warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
  41. f'{cfg.data.imgs_per_gpu} in this experiments')
  42. cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
  43. return cfg
  44. def compat_loader_args(cfg):
  45. """Deprecated sample_per_gpu in cfg.data."""
  46. cfg = copy.deepcopy(cfg)
  47. if 'train_dataloader' not in cfg.data:
  48. cfg.data['train_dataloader'] = ConfigDict()
  49. if 'val_dataloader' not in cfg.data:
  50. cfg.data['val_dataloader'] = ConfigDict()
  51. if 'test_dataloader' not in cfg.data:
  52. cfg.data['test_dataloader'] = ConfigDict()
  53. # special process for train_dataloader
  54. if 'samples_per_gpu' in cfg.data:
  55. samples_per_gpu = cfg.data.pop('samples_per_gpu')
  56. assert 'samples_per_gpu' not in \
  57. cfg.data.train_dataloader, ('`samples_per_gpu` are set '
  58. 'in `data` field and ` '
  59. 'data.train_dataloader` '
  60. 'at the same time. '
  61. 'Please only set it in '
  62. '`data.train_dataloader`. ')
  63. cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu
  64. if 'persistent_workers' in cfg.data:
  65. persistent_workers = cfg.data.pop('persistent_workers')
  66. assert 'persistent_workers' not in \
  67. cfg.data.train_dataloader, ('`persistent_workers` are set '
  68. 'in `data` field and ` '
  69. 'data.train_dataloader` '
  70. 'at the same time. '
  71. 'Please only set it in '
  72. '`data.train_dataloader`. ')
  73. cfg.data.train_dataloader['persistent_workers'] = persistent_workers
  74. if 'workers_per_gpu' in cfg.data:
  75. workers_per_gpu = cfg.data.pop('workers_per_gpu')
  76. cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu
  77. cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu
  78. cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu
  79. # special process for val_dataloader
  80. if 'samples_per_gpu' in cfg.data.val:
  81. # keep default value of `sample_per_gpu` is 1
  82. assert 'samples_per_gpu' not in \
  83. cfg.data.val_dataloader, ('`samples_per_gpu` are set '
  84. 'in `data.val` field and ` '
  85. 'data.val_dataloader` at '
  86. 'the same time. '
  87. 'Please only set it in '
  88. '`data.val_dataloader`. ')
  89. cfg.data.val_dataloader['samples_per_gpu'] = \
  90. cfg.data.val.pop('samples_per_gpu')
  91. # special process for val_dataloader
  92. # in case the test dataset is concatenated
  93. if isinstance(cfg.data.test, dict):
  94. if 'samples_per_gpu' in cfg.data.test:
  95. assert 'samples_per_gpu' not in \
  96. cfg.data.test_dataloader, ('`samples_per_gpu` are set '
  97. 'in `data.test` field and ` '
  98. 'data.test_dataloader` '
  99. 'at the same time. '
  100. 'Please only set it in '
  101. '`data.test_dataloader`. ')
  102. cfg.data.test_dataloader['samples_per_gpu'] = \
  103. cfg.data.test.pop('samples_per_gpu')
  104. elif isinstance(cfg.data.test, list):
  105. for ds_cfg in cfg.data.test:
  106. if 'samples_per_gpu' in ds_cfg:
  107. assert 'samples_per_gpu' not in \
  108. cfg.data.test_dataloader, ('`samples_per_gpu` are set '
  109. 'in `data.test` field and ` '
  110. 'data.test_dataloader` at'
  111. ' the same time. '
  112. 'Please only set it in '
  113. '`data.test_dataloader`. ')
  114. samples_per_gpu = max(
  115. [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
  116. cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu
  117. return cfg