# benchmark.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. from mmengine import MMLogger
  5. from mmengine.config import Config, DictAction
  6. from mmengine.dist import init_dist
  7. from mmengine.registry import init_default_scope
  8. from mmengine.utils import mkdir_or_exist
  9. from mmdet.utils.benchmark import (DataLoaderBenchmark, DatasetBenchmark,
  10. InferenceBenchmark)
  11. def parse_args():
  12. parser = argparse.ArgumentParser(description='MMDet benchmark')
  13. parser.add_argument('config', help='test config file path')
  14. parser.add_argument('--checkpoint', help='checkpoint file')
  15. parser.add_argument(
  16. '--task',
  17. choices=['inference', 'dataloader', 'dataset'],
  18. default='dataloader',
  19. help='Which task do you want to go to benchmark')
  20. parser.add_argument(
  21. '--repeat-num',
  22. type=int,
  23. default=1,
  24. help='number of repeat times of measurement for averaging the results')
  25. parser.add_argument(
  26. '--max-iter', type=int, default=2000, help='num of max iter')
  27. parser.add_argument(
  28. '--log-interval', type=int, default=50, help='interval of logging')
  29. parser.add_argument(
  30. '--num-warmup', type=int, default=5, help='Number of warmup')
  31. parser.add_argument(
  32. '--fuse-conv-bn',
  33. action='store_true',
  34. help='Whether to fuse conv and bn, this will slightly increase'
  35. 'the inference speed')
  36. parser.add_argument(
  37. '--dataset-type',
  38. choices=['train', 'val', 'test'],
  39. default='test',
  40. help='Benchmark dataset type. only supports train, val and test')
  41. parser.add_argument(
  42. '--work-dir',
  43. help='the directory to save the file containing '
  44. 'benchmark metrics')
  45. parser.add_argument(
  46. '--cfg-options',
  47. nargs='+',
  48. action=DictAction,
  49. help='override some settings in the used config, the key-value pair '
  50. 'in xxx=yyy format will be merged into config file. If the value to '
  51. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  52. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  53. 'Note that the quotation marks are necessary and that no white space '
  54. 'is allowed.')
  55. parser.add_argument(
  56. '--launcher',
  57. choices=['none', 'pytorch', 'slurm', 'mpi'],
  58. default='none',
  59. help='job launcher')
  60. parser.add_argument('--local_rank', type=int, default=0)
  61. args = parser.parse_args()
  62. if 'LOCAL_RANK' not in os.environ:
  63. os.environ['LOCAL_RANK'] = str(args.local_rank)
  64. return args
  65. def inference_benchmark(args, cfg, distributed, logger):
  66. benchmark = InferenceBenchmark(
  67. cfg,
  68. args.checkpoint,
  69. distributed,
  70. args.fuse_conv_bn,
  71. args.max_iter,
  72. args.log_interval,
  73. args.num_warmup,
  74. logger=logger)
  75. return benchmark
  76. def dataloader_benchmark(args, cfg, distributed, logger):
  77. benchmark = DataLoaderBenchmark(
  78. cfg,
  79. distributed,
  80. args.dataset_type,
  81. args.max_iter,
  82. args.log_interval,
  83. args.num_warmup,
  84. logger=logger)
  85. return benchmark
  86. def dataset_benchmark(args, cfg, distributed, logger):
  87. benchmark = DatasetBenchmark(
  88. cfg,
  89. args.dataset_type,
  90. args.max_iter,
  91. args.log_interval,
  92. args.num_warmup,
  93. logger=logger)
  94. return benchmark
  95. def main():
  96. args = parse_args()
  97. cfg = Config.fromfile(args.config)
  98. if args.cfg_options is not None:
  99. cfg.merge_from_dict(args.cfg_options)
  100. init_default_scope(cfg.get('default_scope', 'mmdet'))
  101. distributed = False
  102. if args.launcher != 'none':
  103. init_dist(args.launcher, **cfg.get('env_cfg', {}).get('dist_cfg', {}))
  104. distributed = True
  105. log_file = None
  106. if args.work_dir:
  107. log_file = os.path.join(args.work_dir, 'benchmark.log')
  108. mkdir_or_exist(args.work_dir)
  109. logger = MMLogger.get_instance(
  110. 'mmdet', log_file=log_file, log_level='INFO')
  111. benchmark = eval(f'{args.task}_benchmark')(args, cfg, distributed, logger)
  112. benchmark.run(args.repeat_num)
  113. if __name__ == '__main__':
  114. main()