# benchmark_inference_fps.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. from mmengine.config import Config, DictAction
  6. from mmengine.dist import init_dist
  7. from mmengine.fileio import dump
  8. from mmengine.utils import mkdir_or_exist
  9. from terminaltables import GithubFlavoredMarkdownTable
  10. from tools.analysis_tools.benchmark import repeat_measure_inference_speed
  11. def parse_args():
  12. parser = argparse.ArgumentParser(
  13. description='MMDet benchmark a model of FPS')
  14. parser.add_argument('config', help='test config file path')
  15. parser.add_argument('checkpoint_root', help='Checkpoint file root path')
  16. parser.add_argument(
  17. '--round-num',
  18. type=int,
  19. default=1,
  20. help='round a number to a given precision in decimal digits')
  21. parser.add_argument(
  22. '--repeat-num',
  23. type=int,
  24. default=1,
  25. help='number of repeat times of measurement for averaging the results')
  26. parser.add_argument(
  27. '--out', type=str, help='output path of gathered fps to be stored')
  28. parser.add_argument(
  29. '--max-iter', type=int, default=2000, help='num of max iter')
  30. parser.add_argument(
  31. '--log-interval', type=int, default=50, help='interval of logging')
  32. parser.add_argument(
  33. '--fuse-conv-bn',
  34. action='store_true',
  35. help='Whether to fuse conv and bn, this will slightly increase'
  36. 'the inference speed')
  37. parser.add_argument(
  38. '--cfg-options',
  39. nargs='+',
  40. action=DictAction,
  41. help='override some settings in the used config, the key-value pair '
  42. 'in xxx=yyy format will be merged into config file. If the value to '
  43. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  44. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  45. 'Note that the quotation marks are necessary and that no white space '
  46. 'is allowed.')
  47. parser.add_argument(
  48. '--launcher',
  49. choices=['none', 'pytorch', 'slurm', 'mpi'],
  50. default='none',
  51. help='job launcher')
  52. parser.add_argument('--local_rank', type=int, default=0)
  53. args = parser.parse_args()
  54. if 'LOCAL_RANK' not in os.environ:
  55. os.environ['LOCAL_RANK'] = str(args.local_rank)
  56. return args
  57. def results2markdown(result_dict):
  58. table_data = []
  59. is_multiple_results = False
  60. for cfg_name, value in result_dict.items():
  61. name = cfg_name.replace('configs/', '')
  62. fps = value['fps']
  63. ms_times_pre_image = value['ms_times_pre_image']
  64. if isinstance(fps, list):
  65. is_multiple_results = True
  66. mean_fps = value['mean_fps']
  67. mean_times_pre_image = value['mean_times_pre_image']
  68. fps_str = ','.join([str(s) for s in fps])
  69. ms_times_pre_image_str = ','.join(
  70. [str(s) for s in ms_times_pre_image])
  71. table_data.append([
  72. name, fps_str, mean_fps, ms_times_pre_image_str,
  73. mean_times_pre_image
  74. ])
  75. else:
  76. table_data.append([name, fps, ms_times_pre_image])
  77. if is_multiple_results:
  78. table_data.insert(0, [
  79. 'model', 'fps', 'mean_fps', 'times_pre_image(ms)',
  80. 'mean_times_pre_image(ms)'
  81. ])
  82. else:
  83. table_data.insert(0, ['model', 'fps', 'times_pre_image(ms)'])
  84. table = GithubFlavoredMarkdownTable(table_data)
  85. print(table.table, flush=True)
  86. if __name__ == '__main__':
  87. args = parse_args()
  88. assert args.round_num >= 0
  89. assert args.repeat_num >= 1
  90. config = Config.fromfile(args.config)
  91. if args.launcher == 'none':
  92. raise NotImplementedError('Only supports distributed mode')
  93. else:
  94. init_dist(args.launcher)
  95. result_dict = {}
  96. for model_key in config:
  97. model_infos = config[model_key]
  98. if not isinstance(model_infos, list):
  99. model_infos = [model_infos]
  100. for model_info in model_infos:
  101. record_metrics = model_info['metric']
  102. cfg_path = model_info['config'].strip()
  103. cfg = Config.fromfile(cfg_path)
  104. checkpoint = osp.join(args.checkpoint_root,
  105. model_info['checkpoint'].strip())
  106. try:
  107. fps = repeat_measure_inference_speed(cfg, checkpoint,
  108. args.max_iter,
  109. args.log_interval,
  110. args.fuse_conv_bn,
  111. args.repeat_num)
  112. if args.repeat_num > 1:
  113. fps_list = [round(fps_, args.round_num) for fps_ in fps]
  114. times_pre_image_list = [
  115. round(1000 / fps_, args.round_num) for fps_ in fps
  116. ]
  117. mean_fps = round(
  118. sum(fps_list) / len(fps_list), args.round_num)
  119. mean_times_pre_image = round(
  120. sum(times_pre_image_list) / len(times_pre_image_list),
  121. args.round_num)
  122. print(
  123. f'{cfg_path} '
  124. f'Overall fps: {fps_list}[{mean_fps}] img / s, '
  125. f'times per image: '
  126. f'{times_pre_image_list}[{mean_times_pre_image}] '
  127. f'ms / img',
  128. flush=True)
  129. result_dict[cfg_path] = dict(
  130. fps=fps_list,
  131. mean_fps=mean_fps,
  132. ms_times_pre_image=times_pre_image_list,
  133. mean_times_pre_image=mean_times_pre_image)
  134. else:
  135. print(
  136. f'{cfg_path} fps : {fps:.{args.round_num}f} img / s, '
  137. f'times per image: {1000 / fps:.{args.round_num}f} '
  138. f'ms / img',
  139. flush=True)
  140. result_dict[cfg_path] = dict(
  141. fps=round(fps, args.round_num),
  142. ms_times_pre_image=round(1000 / fps, args.round_num))
  143. except Exception as e:
  144. print(f'{cfg_path} error: {repr(e)}')
  145. if args.repeat_num > 1:
  146. result_dict[cfg_path] = dict(
  147. fps=[0],
  148. mean_fps=0,
  149. ms_times_pre_image=[0],
  150. mean_times_pre_image=0)
  151. else:
  152. result_dict[cfg_path] = dict(fps=0, ms_times_pre_image=0)
  153. if args.out:
  154. mkdir_or_exist(args.out)
  155. dump(result_dict, osp.join(args.out, 'batch_inference_fps.json'))
  156. results2markdown(result_dict)