# gather_train_benchmark_metric.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import os.path as osp
  5. from gather_models import get_final_results
  6. from mmengine.config import Config
  7. from mmengine.fileio import dump
  8. from mmengine.utils import mkdir_or_exist
  9. try:
  10. import xlrd
  11. except ImportError:
  12. xlrd = None
  13. try:
  14. import xlutils
  15. from xlutils.copy import copy
  16. except ImportError:
  17. xlutils = None
  18. def parse_args():
  19. parser = argparse.ArgumentParser(
  20. description='Gather benchmarked models metric')
  21. parser.add_argument(
  22. 'root',
  23. type=str,
  24. help='root path of benchmarked models to be gathered')
  25. parser.add_argument(
  26. 'txt_path', type=str, help='txt path output by benchmark_filter')
  27. parser.add_argument(
  28. '--out', type=str, help='output path of gathered metrics to be stored')
  29. parser.add_argument(
  30. '--not-show', action='store_true', help='not show metrics')
  31. parser.add_argument(
  32. '--excel', type=str, help='input path of excel to be recorded')
  33. parser.add_argument(
  34. '--ncol', type=int, help='Number of column to be modified or appended')
  35. args = parser.parse_args()
  36. return args
  37. if __name__ == '__main__':
  38. args = parse_args()
  39. if args.excel:
  40. assert args.ncol, 'Please specify "--excel" and "--ncol" ' \
  41. 'at the same time'
  42. if xlrd is None:
  43. raise RuntimeError(
  44. 'xlrd is not installed,'
  45. 'Please use “pip install xlrd==1.2.0” to install')
  46. if xlutils is None:
  47. raise RuntimeError(
  48. 'xlutils is not installed,'
  49. 'Please use “pip install xlutils==2.0.0” to install')
  50. readbook = xlrd.open_workbook(args.excel)
  51. sheet = readbook.sheet_by_name('Sheet1')
  52. sheet_info = {}
  53. total_nrows = sheet.nrows
  54. for i in range(3, sheet.nrows):
  55. sheet_info[sheet.row_values(i)[0]] = i
  56. xlrw = copy(readbook)
  57. table = xlrw.get_sheet(0)
  58. root_path = args.root
  59. metrics_out = args.out
  60. result_dict = {}
  61. with open(args.txt_path, 'r') as f:
  62. model_cfgs = f.readlines()
  63. for i, config in enumerate(model_cfgs):
  64. config = config.strip()
  65. if len(config) == 0:
  66. continue
  67. config_name = osp.split(config)[-1]
  68. config_name = osp.splitext(config_name)[0]
  69. result_path = osp.join(root_path, config_name)
  70. if osp.exists(result_path):
  71. # 1 read config
  72. cfg = Config.fromfile(config)
  73. total_epochs = cfg.runner.max_epochs
  74. final_results = cfg.evaluation.metric
  75. if not isinstance(final_results, list):
  76. final_results = [final_results]
  77. final_results_out = []
  78. for key in final_results:
  79. if 'proposal_fast' in key:
  80. final_results_out.append('AR@1000') # RPN
  81. elif 'mAP' not in key:
  82. final_results_out.append(key + '_mAP')
  83. # 2 determine whether total_epochs ckpt exists
  84. ckpt_path = f'epoch_{total_epochs}.pth'
  85. if osp.exists(osp.join(result_path, ckpt_path)):
  86. log_json_path = list(
  87. sorted(glob.glob(osp.join(result_path,
  88. '*.log.json'))))[-1]
  89. # 3 read metric
  90. model_performance = get_final_results(
  91. log_json_path, total_epochs, final_results_out)
  92. if model_performance is None:
  93. print(f'log file error: {log_json_path}')
  94. continue
  95. for performance in model_performance:
  96. if performance in ['AR@1000', 'bbox_mAP', 'segm_mAP']:
  97. metric = round(
  98. model_performance[performance] * 100, 1)
  99. model_performance[performance] = metric
  100. result_dict[config] = model_performance
  101. # update and append excel content
  102. if args.excel:
  103. if 'AR@1000' in model_performance:
  104. metrics = f'{model_performance["AR@1000"]}' \
  105. f'(AR@1000)'
  106. elif 'segm_mAP' in model_performance:
  107. metrics = f'{model_performance["bbox_mAP"]}/' \
  108. f'{model_performance["segm_mAP"]}'
  109. else:
  110. metrics = f'{model_performance["bbox_mAP"]}'
  111. row_num = sheet_info.get(config, None)
  112. if row_num:
  113. table.write(row_num, args.ncol, metrics)
  114. else:
  115. table.write(total_nrows, 0, config)
  116. table.write(total_nrows, args.ncol, metrics)
  117. total_nrows += 1
  118. else:
  119. print(f'{config} not exist: {ckpt_path}')
  120. else:
  121. print(f'not exist: {config}')
  122. # 4 save or print results
  123. if metrics_out:
  124. mkdir_or_exist(metrics_out)
  125. dump(result_dict, osp.join(metrics_out, 'model_metric_info.json'))
  126. if not args.not_show:
  127. print('===================================')
  128. for config_name, metrics in result_dict.items():
  129. print(config_name, metrics)
  130. print('===================================')
  131. if args.excel:
  132. filename, sufflx = osp.splitext(args.excel)
  133. xlrw.save(f'{filename}_o{sufflx}')
  134. print(f'>>> Output {filename}_o{sufflx}')