replace_cfg_vals.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import re
  3. from mmengine.config import Config
  4. def replace_cfg_vals(ori_cfg):
  5. """Replace the string "${key}" with the corresponding value.
  6. Replace the "${key}" with the value of ori_cfg.key in the config. And
  7. support replacing the chained ${key}. Such as, replace "${key0.key1}"
  8. with the value of cfg.key0.key1. Code is modified from `vars.py
  9. < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501
  10. Args:
  11. ori_cfg (mmengine.config.Config):
  12. The origin config with "${key}" generated from a file.
  13. Returns:
  14. updated_cfg [mmengine.config.Config]:
  15. The config with "${key}" replaced by the corresponding value.
  16. """
  17. def get_value(cfg, key):
  18. for k in key.split('.'):
  19. cfg = cfg[k]
  20. return cfg
  21. def replace_value(cfg):
  22. if isinstance(cfg, dict):
  23. return {key: replace_value(value) for key, value in cfg.items()}
  24. elif isinstance(cfg, list):
  25. return [replace_value(item) for item in cfg]
  26. elif isinstance(cfg, tuple):
  27. return tuple([replace_value(item) for item in cfg])
  28. elif isinstance(cfg, str):
  29. # the format of string cfg may be:
  30. # 1) "${key}", which will be replaced with cfg.key directly
  31. # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
  32. # which will be replaced with the string of the cfg.key
  33. keys = pattern_key.findall(cfg)
  34. values = [get_value(ori_cfg, key[2:-1]) for key in keys]
  35. if len(keys) == 1 and keys[0] == cfg:
  36. # the format of string cfg is "${key}"
  37. cfg = values[0]
  38. else:
  39. for key, value in zip(keys, values):
  40. # the format of string cfg is
  41. # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
  42. assert not isinstance(value, (dict, list, tuple)), \
  43. f'for the format of string cfg is ' \
  44. f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \
  45. f"the type of the value of '${key}' " \
  46. f'can not be dict, list, or tuple' \
  47. f'but you input {type(value)} in {cfg}'
  48. cfg = cfg.replace(key, str(value))
  49. return cfg
  50. else:
  51. return cfg
  52. # the pattern of string "${key}"
  53. pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')
  54. # the type of ori_cfg._cfg_dict is mmengine.config.ConfigDict
  55. updated_cfg = Config(
  56. replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
  57. # replace the model with model_wrapper
  58. if updated_cfg.get('model_wrapper', None) is not None:
  59. updated_cfg.model = updated_cfg.model_wrapper
  60. updated_cfg.pop('model_wrapper')
  61. return updated_cfg