image_demo.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from argparse import ArgumentParser
  3. from mmcv.image import imread
  4. from mmpose.apis import inference_topdown, init_model
  5. from mmpose.registry import VISUALIZERS
  6. from mmpose.structures import merge_data_samples
  7. def parse_args():
  8. parser = ArgumentParser()
  9. parser.add_argument('img', help='Image file')
  10. parser.add_argument('config', help='Config file')
  11. parser.add_argument('checkpoint', help='Checkpoint file')
  12. parser.add_argument('--out-file', default=None, help='Path to output file')
  13. parser.add_argument(
  14. '--device', default='cuda:0', help='Device used for inference')
  15. parser.add_argument(
  16. '--draw-heatmap',
  17. action='store_true',
  18. help='Visualize the predicted heatmap')
  19. parser.add_argument(
  20. '--show-kpt-idx',
  21. action='store_true',
  22. default=False,
  23. help='Whether to show the index of keypoints')
  24. parser.add_argument(
  25. '--skeleton-style',
  26. default='mmpose',
  27. type=str,
  28. choices=['mmpose', 'openpose'],
  29. help='Skeleton style selection')
  30. parser.add_argument(
  31. '--kpt-thr',
  32. type=float,
  33. default=0.3,
  34. help='Visualizing keypoint thresholds')
  35. parser.add_argument(
  36. '--radius',
  37. type=int,
  38. default=3,
  39. help='Keypoint radius for visualization')
  40. parser.add_argument(
  41. '--thickness',
  42. type=int,
  43. default=1,
  44. help='Link thickness for visualization')
  45. parser.add_argument(
  46. '--alpha', type=float, default=0.8, help='The transparency of bboxes')
  47. parser.add_argument(
  48. '--show',
  49. action='store_true',
  50. default=False,
  51. help='whether to show img')
  52. args = parser.parse_args()
  53. return args
  54. def main():
  55. args = parse_args()
  56. # build the model from a config file and a checkpoint file
  57. if args.draw_heatmap:
  58. cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
  59. else:
  60. cfg_options = None
  61. model = init_model(
  62. args.config,
  63. args.checkpoint,
  64. device=args.device,
  65. cfg_options=cfg_options)
  66. # init visualizer
  67. model.cfg.visualizer.radius = args.radius
  68. model.cfg.visualizer.alpha = args.alpha
  69. model.cfg.visualizer.line_width = args.thickness
  70. visualizer = VISUALIZERS.build(model.cfg.visualizer)
  71. visualizer.set_dataset_meta(
  72. model.dataset_meta, skeleton_style=args.skeleton_style)
  73. # inference a single image
  74. batch_results = inference_topdown(model, args.img)
  75. results = merge_data_samples(batch_results)
  76. # show the results
  77. img = imread(args.img, channel_order='rgb')
  78. visualizer.add_datasample(
  79. 'result',
  80. img,
  81. data_sample=results,
  82. draw_gt=False,
  83. draw_bbox=True,
  84. kpt_thr=args.kpt_thr,
  85. draw_heatmap=args.draw_heatmap,
  86. show_kpt_idx=args.show_kpt_idx,
  87. skeleton_style=args.skeleton_style,
  88. show=args.show,
  89. out_file=args.out_file)
  90. if __name__ == '__main__':
  91. main()