gather_test_benchmark_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import os.path as osp

from mmengine.config import Config
from mmengine.fileio import dump, load
from mmengine.utils import mkdir_or_exist
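
# The list config given on the command line is expected to map each model name
# to a dict (or a list of dicts) holding the model's test config path and its
# recorded metrics. A minimal sketch of one entry (the model name, config path
# and metric value are illustrative, not taken from a real benchmark list):
#
#   atss = dict(
#       config='configs/atss/atss_r50_fpn_1x_coco.py',
#       metric=dict(bbox_mAP=39.4))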


def parse_args():
    parser = argparse.ArgumentParser(
        description="Gather benchmarked models' metrics")
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        'root',
        type=str,
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        '--out', type=str, help='output path of gathered metrics to be stored')
    parser.add_argument(
        '--not-show', action='store_true', help='do not show metrics')
    parser.add_argument(
        '--show-all', action='store_true', help='show all model metrics')
    args = parser.parse_args()
    return args
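
# Example invocation (the list config and work-dir paths below are
# illustrative):
#
#   python gather_test_benchmark_metric.py batch_test_list.py \
#       work_dirs/benchmark_test --out metric_info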

if __name__ == '__main__':
    args = parse_args()

    root_path = args.root
    metrics_out = args.out
    result_dict = {}

    cfg = Config.fromfile(args.config)

    for model_key in cfg:
        model_infos = cfg[model_key]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            record_metrics = model_info['metric']
            config = model_info['config'].strip()
            fname, _ = osp.splitext(osp.basename(config))
            metric_json_dir = osp.join(root_path, fname)
            if osp.exists(metric_json_dir):
                json_list = glob.glob(osp.join(metric_json_dir, '*.json'))
                if len(json_list) > 0:
                    # pick the most recent json log by sorted filename
                    log_json_path = sorted(json_list)[-1]
                    metric = load(log_json_path)
                    # the json log records the config it was generated from
                    if config in metric.get('config', {}):
                        # rescale gathered metrics to the 0-100 range used
                        # by the recorded ones
                        new_metrics = dict()
                        for record_metric_key in record_metrics:
                            record_metric_key_bk = record_metric_key
                            if record_metric_key == 'AR_1000':
                                record_metric_key = 'AR@1000'
                            if record_metric_key not in metric['metric']:
                                raise KeyError(
                                    f'{record_metric_key} not found in '
                                    f'{log_json_path}, please check your '
                                    'config')
                            new_metric = round(
                                metric['metric'][record_metric_key] * 100, 1)
                            new_metrics[record_metric_key_bk] = new_metric

                        if args.show_all:
                            result_dict[config] = dict(
                                before=record_metrics, after=new_metrics)
                        else:
                            # only keep models whose metrics changed
                            for record_metric_key in record_metrics:
                                old_metric = record_metrics[record_metric_key]
                                new_metric = new_metrics[record_metric_key]
                                if old_metric != new_metric:
                                    result_dict[config] = dict(
                                        before=record_metrics,
                                        after=new_metrics)
                                    break
                    else:
                        print(f'{config} not included in: {log_json_path}')
                else:
                    print(f'{config}: no json logs found in '
                          f'{metric_json_dir}')
            else:
                print(f'{config}: directory does not exist: '
                      f'{metric_json_dir}')
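
    # result_dict maps each config path to its recorded ('before') and freshly
    # gathered ('after') metrics, e.g. (values illustrative):
    #   {'configs/atss/atss_r50_fpn_1x_coco.py':
    #       {'before': {'bbox_mAP': 39.4}, 'after': {'bbox_mAP': 39.5}}}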
    if metrics_out:
        mkdir_or_exist(metrics_out)
        dump(result_dict,
             osp.join(metrics_out, 'batch_test_metric_info.json'))
    if not args.not_show:
        print('===================================')
        for config_name, metrics in result_dict.items():
            print(config_name, metrics)
        print('===================================')