convert_test_benchmark_script.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. from mmengine import Config
  6. def parse_args():
  7. parser = argparse.ArgumentParser(
  8. description='Convert benchmark model list to script')
  9. parser.add_argument('config', help='test config file path')
  10. parser.add_argument('--port', type=int, default=29666, help='dist port')
  11. parser.add_argument(
  12. '--run', action='store_true', help='run script directly')
  13. parser.add_argument(
  14. '--out', type=str, help='path to save model benchmark script')
  15. args = parser.parse_args()
  16. return args
  17. def process_model_info(model_info, work_dir):
  18. config = model_info['config'].strip()
  19. fname, _ = osp.splitext(osp.basename(config))
  20. job_name = fname
  21. work_dir = '$WORK_DIR/' + fname
  22. checkpoint = model_info['checkpoint'].strip()
  23. return dict(
  24. config=config,
  25. job_name=job_name,
  26. work_dir=work_dir,
  27. checkpoint=checkpoint)
  28. def create_test_bash_info(commands, model_test_dict, port, script_name,
  29. partition):
  30. config = model_test_dict['config']
  31. job_name = model_test_dict['job_name']
  32. checkpoint = model_test_dict['checkpoint']
  33. work_dir = model_test_dict['work_dir']
  34. echo_info = f' \necho \'{config}\' &'
  35. commands.append(echo_info)
  36. commands.append('\n')
  37. command_info = f'GPUS=8 GPUS_PER_NODE=8 ' \
  38. f'CPUS_PER_TASK=$CPUS_PRE_TASK {script_name} '
  39. command_info += f'{partition} '
  40. command_info += f'{job_name} '
  41. command_info += f'{config} '
  42. command_info += f'$CHECKPOINT_DIR/{checkpoint} '
  43. command_info += f'--work-dir {work_dir} '
  44. command_info += f'--cfg-option env_cfg.dist_cfg.port={port} '
  45. command_info += ' &'
  46. commands.append(command_info)
  47. def main():
  48. args = parse_args()
  49. if args.out:
  50. out_suffix = args.out.split('.')[-1]
  51. assert args.out.endswith('.sh'), \
  52. f'Expected out file path suffix is .sh, but get .{out_suffix}'
  53. assert args.out or args.run, \
  54. ('Please specify at least one operation (save/run/ the '
  55. 'script) with the argument "--out" or "--run"')
  56. commands = []
  57. partition_name = 'PARTITION=$1 '
  58. commands.append(partition_name)
  59. commands.append('\n')
  60. checkpoint_root = 'CHECKPOINT_DIR=$2 '
  61. commands.append(checkpoint_root)
  62. commands.append('\n')
  63. work_dir = 'WORK_DIR=$3 '
  64. commands.append(work_dir)
  65. commands.append('\n')
  66. cpus_pre_task = 'CPUS_PER_TASK=${4:-2} '
  67. commands.append(cpus_pre_task)
  68. commands.append('\n')
  69. script_name = osp.join('tools', 'slurm_test.sh')
  70. port = args.port
  71. cfg = Config.fromfile(args.config)
  72. for model_key in cfg:
  73. model_infos = cfg[model_key]
  74. if not isinstance(model_infos, list):
  75. model_infos = [model_infos]
  76. for model_info in model_infos:
  77. print('processing: ', model_info['config'])
  78. model_test_dict = process_model_info(model_info, work_dir)
  79. create_test_bash_info(commands, model_test_dict, port, script_name,
  80. '$PARTITION')
  81. port += 1
  82. command_str = ''.join(commands)
  83. if args.out:
  84. with open(args.out, 'w') as f:
  85. f.write(command_str)
  86. if args.run:
  87. os.system(command_str)
  88. if __name__ == '__main__':
  89. main()