123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from argparse import ArgumentParser
- from mmcv.image import imread
- from mmpose.apis import inference_topdown, init_model
- from mmpose.registry import VISUALIZERS
- from mmpose.structures import merge_data_samples
def parse_args(argv=None):
    """Parse the command-line options of the top-down image demo.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is how the
            script entry point uses it; passing an explicit list allows the
            parser to be reused programmatically or exercised in tests.

    Returns:
        argparse.Namespace: The parsed options. ``img``, ``config`` and
        ``checkpoint`` are the positional arguments; every optional flag is
        present with the default declared below when it is not given.
    """
    parser = ArgumentParser()

    # Required positional inputs.
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')

    # Output / runtime options.
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')

    # Visualization options.
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Visualize the predicted heatmap')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show img')

    # argv=None keeps the original behaviour of reading sys.argv.
    return parser.parse_args(argv)
def main():
    """Run top-down pose estimation on a single image and visualize it.

    Parses the CLI options, builds the model from ``args.config`` and
    ``args.checkpoint`` on ``args.device``, runs ``inference_topdown`` on
    ``args.img``, merges the per-instance results and draws them with the
    visualizer built from the model config. The drawing is shown when
    ``--show`` is set and is passed ``--out-file`` as its output path.
    """
    import warnings

    args = parse_args()

    # Without --show or --out-file the rendered image is neither displayed
    # nor written anywhere; warn before the (potentially slow) model load so
    # the user can correct the invocation.
    if not args.show and args.out_file is None:
        warnings.warn(
            'Neither --show nor --out-file was given; the visualization '
            'will be neither displayed nor saved.')

    # When --draw-heatmap is set, override the model's test_cfg so it also
    # outputs heatmaps (presumably required for the visualizer to draw them).
    if args.draw_heatmap:
        cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
    else:
        cfg_options = None

    # Build the model from a config file and a checkpoint file.
    model = init_model(
        args.config,
        args.checkpoint,
        device=args.device,
        cfg_options=cfg_options)

    # Push the CLI drawing options into the visualizer config before the
    # visualizer is built from it.
    model.cfg.visualizer.radius = args.radius
    model.cfg.visualizer.alpha = args.alpha
    model.cfg.visualizer.line_width = args.thickness
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(
        model.dataset_meta, skeleton_style=args.skeleton_style)

    # Inference on a single image. No bboxes are passed here; presumably the
    # whole image is then treated as one instance -- confirm in mmpose docs.
    batch_results = inference_topdown(model, args.img)
    results = merge_data_samples(batch_results)

    # Load the image in RGB order for drawing (assumed to be the order the
    # visualizer expects -- confirm against the visualizer implementation).
    img = imread(args.img, channel_order='rgb')
    visualizer.add_datasample(
        'result',
        img,
        data_sample=results,
        draw_gt=False,
        draw_bbox=True,
        kpt_thr=args.kpt_thr,
        draw_heatmap=args.draw_heatmap,
        show_kpt_idx=args.show_kpt_idx,
        skeleton_style=args.skeleton_style,
        show=args.show,
        out_file=args.out_file)

    # Tell the user where the result went; previously this was silent.
    if args.out_file is not None:
        print(f'The output image has been saved at {args.out_file}')
if __name__ == '__main__':
    # Script entry point: run the demo only when this file is executed
    # directly, not when it is imported as a module.
    main()
|