12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import re
- from mmengine.config import Config
def replace_cfg_vals(ori_cfg):
    """Replace the string "${key}" with the corresponding value.

    Replace the "${key}" with the value of ori_cfg.key in the config. And
    support replacing the chained ${key}. Such as, replace "${key0.key1}"
    with the value of cfg.key0.key1. Code is modified from `vars.py
    <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501

    Placeholders are looked up in ``ori_cfg`` itself. A referenced value
    that itself contains "${...}" is therefore inserted as-is and is not
    resolved a second time. Dicts, lists and tuples are walked recursively.
    Any other non-string value is returned unchanged.

    Args:
        ori_cfg (mmengine.config.Config):
            The origin config with "${key}" generated from a file.

    Returns:
        updated_cfg [mmengine.config.Config]:
            The config with "${key}" replaced by the corresponding value.
            If ``model_wrapper`` is set to a non-None value, that value
            replaces ``model`` and the ``model_wrapper`` entry is removed.

    Raises:
        AssertionError: If a placeholder embedded in a longer string
            (e.g. "xxx${key}xxx") resolves to a dict, list or tuple.
    """
    # the pattern of string "${key}"; only letters, digits, "_" and "."
    # are accepted between the braces
    pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')

    def get_value(cfg, key):
        # Walk a dotted path such as "key0.key1" one level at a time.
        # NOTE(review): a missing key presumably raises KeyError from the
        # underlying config mapping -- confirm against mmengine ConfigDict.
        for k in key.split('.'):
            cfg = cfg[k]
        return cfg

    def replace_value(cfg):
        if isinstance(cfg, dict):
            return {key: replace_value(value) for key, value in cfg.items()}
        if isinstance(cfg, list):
            return [replace_value(item) for item in cfg]
        if isinstance(cfg, tuple):
            return tuple(replace_value(item) for item in cfg)
        if not isinstance(cfg, str):
            return cfg

        # the format of string cfg may be:
        # 1) "${key}", which will be replaced with cfg.key directly
        # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
        #    which will be replaced with the string of the cfg.key
        keys = pattern_key.findall(cfg)
        # key[2:-1] strips the leading "${" and the trailing "}"
        values = [get_value(ori_cfg, key[2:-1]) for key in keys]
        if len(keys) == 1 and keys[0] == cfg:
            # The whole string is a single placeholder: return the value
            # itself so its original type (int, dict, ...) is kept.
            return values[0]

        result = cfg
        for key, value in zip(keys, values):
            # An embedded placeholder is stringified, which is meaningless
            # for containers. Raise explicitly rather than using `assert`
            # so the check survives `python -O`. The exception type stays
            # AssertionError for existing callers.
            if isinstance(value, (dict, list, tuple)):
                raise AssertionError(
                    'for the format of string cfg is '
                    "'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', "
                    f"the type of the value of '{key}' "
                    'can not be dict, list, or tuple, '
                    f'but you input {type(value)} in {cfg}')
            result = result.replace(key, str(value))
        return result

    # the type of ori_cfg._cfg_dict is mmengine.config.ConfigDict
    updated_cfg = Config(
        replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
    # replace the model with model_wrapper
    if updated_cfg.get('model_wrapper', None) is not None:
        updated_cfg.model = updated_cfg.model_wrapper
        updated_cfg.pop('model_wrapper')
    return updated_cfg
|