123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import warnings
- from mmengine.config import ConfigDict
def compat_cfg(cfg):
    """Return a copy of ``cfg`` with deprecated fields moved into place.

    The input is deep-copied first and then passed through each
    compatibility helper in turn: ``compat_imgs_per_gpu``,
    ``compat_loader_args`` and finally ``compat_runner_args``.

    Args:
        cfg: The config object to upgrade.

    Returns:
        The upgraded copy of ``cfg``.
    """
    upgraded = copy.deepcopy(cfg)
    # Order matters: `imgs_per_gpu` is mirrored into `samples_per_gpu`
    # before the loader pass relocates `samples_per_gpu`.
    for apply_fix in (compat_imgs_per_gpu, compat_loader_args,
                      compat_runner_args):
        upgraded = apply_fix(upgraded)
    return upgraded
def compat_runner_args(cfg):
    """Make sure ``cfg`` has a ``runner`` section.

    When ``runner`` is missing, an ``EpochBasedRunner`` entry is created
    from ``cfg.total_epochs`` and a ``UserWarning`` is emitted. When
    ``runner`` is present together with ``total_epochs``, the two epoch
    counts must agree.

    Note that ``cfg`` is modified in place (no copy is taken here).

    Args:
        cfg: Config object supporting both ``in`` checks and attribute
            access. ``cfg.total_epochs`` is read when ``runner`` is absent.

    Returns:
        The same ``cfg`` object.

    Raises:
        AssertionError: If ``cfg.total_epochs`` differs from
            ``cfg.runner.max_epochs``.
    """
    if 'runner' not in cfg:
        cfg.runner = ConfigDict({
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        })
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    elif 'total_epochs' in cfg and \
            cfg.total_epochs != cfg.runner.max_epochs:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for
        # callers that already catch AssertionError.
        raise AssertionError(
            f'`total_epochs` ({cfg.total_epochs}) does not match '
            f'`runner.max_epochs` ({cfg.runner.max_epochs}); '
            'please make them consistent or remove `total_epochs`.')
    return cfg
def compat_imgs_per_gpu(cfg):
    """Mirror the deprecated ``data.imgs_per_gpu`` into ``samples_per_gpu``.

    Works on a deep copy of ``cfg``. If ``data.imgs_per_gpu`` is present, a
    deprecation warning is issued and its value is written to
    ``data.samples_per_gpu``, replacing any value already there. The
    ``imgs_per_gpu`` key itself is left in place.

    Args:
        cfg: Config object whose ``data`` section is inspected.

    Returns:
        The (possibly updated) copy of ``cfg``.
    """
    cfg = copy.deepcopy(cfg)
    data_section = cfg.data

    # Nothing to migrate when the deprecated key is absent.
    if 'imgs_per_gpu' not in data_section:
        return cfg

    warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                  'Please use "samples_per_gpu" instead')

    legacy_value = data_section.imgs_per_gpu
    if 'samples_per_gpu' in data_section:
        notice = (f'Got "imgs_per_gpu"={legacy_value} and '
                  f'"samples_per_gpu"={data_section.samples_per_gpu}, '
                  f'"imgs_per_gpu"'
                  f'={legacy_value} is used in this experiments')
    else:
        notice = ('Automatically set "samples_per_gpu"="imgs_per_gpu"='
                  f'{legacy_value} in this experiments')
    warnings.warn(notice)

    # The deprecated value wins in both cases.
    data_section.samples_per_gpu = legacy_value
    return cfg
def _raise_if_already_set(loader_cfg, key, source, target):
    """Raise ``AssertionError`` if ``key`` is already set in ``loader_cfg``.

    ``source`` and ``target`` are the dotted config paths named in the
    error message.
    """
    if key in loader_cfg:
        # Raised explicitly instead of via `assert` so the conflict check
        # still runs under `python -O`.
        raise AssertionError(
            f'`{key}` are set in `{source}` field and ` {target}` at the '
            f'same time. Please only set it in `{target}`. ')


def compat_loader_args(cfg):
    """Move deprecated loader arguments into the ``*_dataloader`` sections.

    Works on a deep copy of ``cfg``. Missing ``data.train_dataloader``,
    ``data.val_dataloader`` and ``data.test_dataloader`` sections are
    created empty, then:

    * ``data.samples_per_gpu`` and ``data.persistent_workers`` are moved to
      ``data.train_dataloader``;
    * ``data.workers_per_gpu`` is moved to all three dataloader sections,
      replacing any value already there;
    * ``data.val.samples_per_gpu`` is moved to ``data.val_dataloader``;
    * ``data.test.samples_per_gpu`` is moved to ``data.test_dataloader``.
      When ``data.test`` is a list, the largest value declared by its
      entries is used (entries without the key count as 1).

    ``data.val`` and ``data.test`` are optional. When ``data.test`` is a
    list and none of its entries declares ``samples_per_gpu``, an existing
    ``data.test_dataloader.samples_per_gpu`` is kept; otherwise it
    defaults to 1.

    Args:
        cfg: Config object with a ``data`` section.

    Returns:
        The updated copy of ``cfg``.

    Raises:
        AssertionError: If an argument is set both in its deprecated
            location and in the corresponding dataloader section.
    """
    cfg = copy.deepcopy(cfg)
    data_cfg = cfg.data

    for loader_name in ('train_dataloader', 'val_dataloader',
                        'test_dataloader'):
        if loader_name not in data_cfg:
            data_cfg[loader_name] = ConfigDict()

    # special process for train_dataloader
    for key in ('samples_per_gpu', 'persistent_workers'):
        if key in data_cfg:
            _raise_if_already_set(data_cfg.train_dataloader, key, 'data',
                                  'data.train_dataloader')
            data_cfg.train_dataloader[key] = data_cfg.pop(key)

    # `workers_per_gpu` is shared by every loader
    if 'workers_per_gpu' in data_cfg:
        workers_per_gpu = data_cfg.pop('workers_per_gpu')
        data_cfg.train_dataloader['workers_per_gpu'] = workers_per_gpu
        data_cfg.val_dataloader['workers_per_gpu'] = workers_per_gpu
        data_cfg.test_dataloader['workers_per_gpu'] = workers_per_gpu

    # special process for val_dataloader (the `val` split is optional)
    val_cfg = data_cfg.get('val')
    if val_cfg is not None and 'samples_per_gpu' in val_cfg:
        _raise_if_already_set(data_cfg.val_dataloader, 'samples_per_gpu',
                              'data.val', 'data.val_dataloader')
        data_cfg.val_dataloader['samples_per_gpu'] = \
            val_cfg.pop('samples_per_gpu')

    # special process for test_dataloader (the `test` split is optional)
    test_cfg = data_cfg.get('test')
    test_loader = data_cfg.test_dataloader
    if isinstance(test_cfg, dict):
        if 'samples_per_gpu' in test_cfg:
            _raise_if_already_set(test_loader, 'samples_per_gpu',
                                  'data.test', 'data.test_dataloader')
            test_loader['samples_per_gpu'] = test_cfg.pop('samples_per_gpu')
    elif isinstance(test_cfg, list):
        # in case the test dataset is concatenated
        if any('samples_per_gpu' in ds_cfg for ds_cfg in test_cfg):
            _raise_if_already_set(test_loader, 'samples_per_gpu',
                                  'data.test', 'data.test_dataloader')
            test_loader['samples_per_gpu'] = max(
                ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in test_cfg)
        elif 'samples_per_gpu' not in test_loader:
            # No dataset declares a value: fall back to 1, but never
            # overwrite a value set explicitly on the test dataloader.
            test_loader['samples_per_gpu'] = 1
    return cfg
|