misc.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import glob
  3. import os
  4. import os.path as osp
  5. import warnings
  6. from typing import Union
  7. from mmengine.config import Config, ConfigDict
  8. from mmengine.logging import print_log
  def find_latest_checkpoint(path, suffix='pth'):
      """Find the latest checkpoint from the working directory.

      A ``latest.{suffix}`` file, if present, is returned directly. Otherwise
      every ``*.{suffix}`` file directly inside ``path`` is examined, and the
      one with the largest integer between the last ``_`` of its file name
      and the first ``.`` after it (e.g. ``12`` for ``epoch_12.pth``) wins.
      Files whose name carries no such integer (e.g.
      ``best_coco_bbox_mAP.pth``) are ignored instead of aborting the search.

      Args:
          path (str): The path to find checkpoints.
          suffix (str): File extension. Defaults to pth.

      Returns:
          str | None: File path of the latest checkpoint, or ``None`` (after
          emitting a warning) when ``path`` does not exist or contains no
          numbered checkpoint.

      References:
          .. [1] https://github.com/microsoft/SoftTeacher
                    /blob/main/ssod/utils/patch.py
      """
      if not osp.exists(path):
          warnings.warn('The path of checkpoints does not exist.')
          return None

      latest_link = osp.join(path, f'latest.{suffix}')
      if osp.exists(latest_link):
          return latest_link

      # Escape ``path`` so glob metacharacters in the directory name (such
      # as ``[``) are matched literally rather than interpreted as a pattern.
      checkpoints = glob.glob(osp.join(glob.escape(path), f'*.{suffix}'))
      if not checkpoints:
          warnings.warn('There are no checkpoints in the path.')
          return None

      latest = -1
      latest_path = None
      for checkpoint in checkpoints:
          # ``epoch_12.pth`` -> ``'12'``. Names without a numeric tail are
          # skipped rather than letting int() raise ValueError.
          stem = osp.basename(checkpoint).split('_')[-1].split('.')[0]
          try:
              count = int(stem)
          except ValueError:
              continue
          if count > latest:
              latest = count
              latest_path = checkpoint

      if latest_path is None:
          warnings.warn('There are no numbered checkpoints in the path.')
      return latest_path
  def update_data_root(cfg, logger=None):
      """Update data root according to env MMDET_DATASETS.

      If the environment variable ``MMDET_DATASETS`` is set, every string
      found under ``cfg.data`` (recursing into nested dicts and lists) that
      contains the current ``cfg.data_root`` has that substring replaced by
      the value of ``MMDET_DATASETS``. ``cfg.data_root`` is then set to that
      value. If the variable is not set, ``cfg`` is left unchanged.

      Args:
          cfg (:obj:`Config`): The model config to modify in place.
          logger (logging.Logger | str | None): Forwarded to ``print_log`` to
              choose where the message is printed. Defaults to None.
      """
      assert isinstance(cfg, Config), \
          f'cfg got wrong type: {type(cfg)}, expected mmengine.Config'

      if 'MMDET_DATASETS' not in os.environ:
          return
      dst_root = os.environ['MMDET_DATASETS']
      print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
                f'Using {dst_root} as data root.', logger=logger)

      def _replace_root(node, src_str, dst_str):
          """Replace ``src_str`` in place in every string under ``node``."""
          # Lists are walked too, so entries such as
          # ``ann_file=[data_root + 'a', data_root + 'b']`` are updated.
          pairs = enumerate(node) if isinstance(node, list) else node.items()
          for key, value in list(pairs):
              if isinstance(value, (ConfigDict, list)):
                  _replace_root(value, src_str, dst_str)
              elif isinstance(value, str) and src_str in value:
                  node[key] = value.replace(src_str, dst_str)

      src_root = cfg.data_root
      # An empty root is a substring of every string; str.replace would then
      # splice ``dst_root`` between every character, so skip rewriting.
      if src_root:
          _replace_root(cfg.data, src_root, dst_root)
      cfg.data_root = dst_root
  def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
      """Get the test dataset pipeline from entire config.

      Starting at ``cfg.test_dataloader.dataset``, dataset wrappers are
      unwrapped, via their ``dataset`` key or the first element of their
      ``datasets`` key, until a config holding a ``pipeline`` key is reached.

      Args:
          cfg (str or :obj:`ConfigDict`): The entire config, either a path to
              a config file or an already loaded ``ConfigDict``.

      Returns:
          :obj:`ConfigDict`: The ``pipeline`` entry of the innermost test
          dataset config.

      Raises:
          RuntimeError: If no ``pipeline`` key can be reached.
      """
      def _find_pipeline(ds_cfg):
          if 'pipeline' in ds_cfg:
              return ds_cfg.pipeline
          if 'dataset' in ds_cfg:
              # Wrapper around a single inner dataset.
              return _find_pipeline(ds_cfg.dataset)
          if 'datasets' in ds_cfg:
              # Wrapper around several datasets (e.g. ConcatDataset); only
              # the first one is inspected.
              return _find_pipeline(ds_cfg.datasets[0])
          raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')

      config = Config.fromfile(cfg) if isinstance(cfg, str) else cfg
      return _find_pipeline(config.test_dataloader.dataset)