benchmark_test_image.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import logging
  3. import os.path as osp
  4. from argparse import ArgumentParser
  5. import mmcv
  6. from mmengine.config import Config
  7. from mmengine.logging import MMLogger
  8. from mmengine.utils import mkdir_or_exist
  9. from mmdet.apis import inference_detector, init_detector
  10. from mmdet.registry import VISUALIZERS
  11. from mmdet.utils import register_all_modules
  12. def parse_args():
  13. parser = ArgumentParser()
  14. parser.add_argument('config', help='test config file path')
  15. parser.add_argument('checkpoint_root', help='Checkpoint file root path')
  16. parser.add_argument('--img', default='demo/demo.jpg', help='Image file')
  17. parser.add_argument('--aug', action='store_true', help='aug test')
  18. parser.add_argument('--model-name', help='model name to inference')
  19. parser.add_argument('--show', action='store_true', help='show results')
  20. parser.add_argument('--out-dir', default=None, help='Dir to output file')
  21. parser.add_argument(
  22. '--wait-time',
  23. type=float,
  24. default=1,
  25. help='the interval of show (s), 0 is block')
  26. parser.add_argument(
  27. '--device', default='cuda:0', help='Device used for inference')
  28. parser.add_argument(
  29. '--palette',
  30. default='coco',
  31. choices=['coco', 'voc', 'citys', 'random'],
  32. help='Color palette used for visualization')
  33. parser.add_argument(
  34. '--score-thr', type=float, default=0.3, help='bbox score threshold')
  35. args = parser.parse_args()
  36. return args
  37. def inference_model(config_name, checkpoint, visualizer, args, logger=None):
  38. cfg = Config.fromfile(config_name)
  39. if args.aug:
  40. raise NotImplementedError()
  41. model = init_detector(
  42. cfg, checkpoint, palette=args.palette, device=args.device)
  43. visualizer.dataset_meta = model.dataset_meta
  44. # test a single image
  45. result = inference_detector(model, args.img)
  46. # show the results
  47. if args.show or args.out_dir is not None:
  48. img = mmcv.imread(args.img)
  49. img = mmcv.imconvert(img, 'bgr', 'rgb')
  50. out_file = None
  51. if args.out_dir is not None:
  52. out_dir = args.out_dir
  53. mkdir_or_exist(out_dir)
  54. out_file = osp.join(
  55. out_dir,
  56. config_name.split('/')[-1].replace('py', 'jpg'))
  57. visualizer.add_datasample(
  58. 'result',
  59. img,
  60. data_sample=result,
  61. draw_gt=False,
  62. show=args.show,
  63. wait_time=args.wait_time,
  64. out_file=out_file,
  65. pred_score_thr=args.score_thr)
  66. return result
  67. # Sample test whether the inference code is correct
  68. def main(args):
  69. # register all modules in mmdet into the registries
  70. register_all_modules()
  71. config = Config.fromfile(args.config)
  72. # init visualizer
  73. visualizer_cfg = dict(type='DetLocalVisualizer', name='visualizer')
  74. visualizer = VISUALIZERS.build(visualizer_cfg)
  75. # test single model
  76. if args.model_name:
  77. if args.model_name in config:
  78. model_infos = config[args.model_name]
  79. if not isinstance(model_infos, list):
  80. model_infos = [model_infos]
  81. model_info = model_infos[0]
  82. config_name = model_info['config'].strip()
  83. print(f'processing: {config_name}', flush=True)
  84. checkpoint = osp.join(args.checkpoint_root,
  85. model_info['checkpoint'].strip())
  86. # build the model from a config file and a checkpoint file
  87. inference_model(config_name, checkpoint, visualizer, args)
  88. return
  89. else:
  90. raise RuntimeError('model name input error.')
  91. # test all model
  92. logger = MMLogger.get_instance(
  93. name='MMLogger',
  94. log_file='benchmark_test_image.log',
  95. log_level=logging.ERROR)
  96. for model_key in config:
  97. model_infos = config[model_key]
  98. if not isinstance(model_infos, list):
  99. model_infos = [model_infos]
  100. for model_info in model_infos:
  101. print('processing: ', model_info['config'], flush=True)
  102. config_name = model_info['config'].strip()
  103. checkpoint = osp.join(args.checkpoint_root,
  104. model_info['checkpoint'].strip())
  105. try:
  106. # build the model from a config file and a checkpoint file
  107. inference_model(config_name, checkpoint, visualizer, args,
  108. logger)
  109. except Exception as e:
  110. logger.error(f'{config_name} " : {repr(e)}')
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the benchmark check.
    args = parse_args()
    main(args)