# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False


def process_one_image(args,
                      img,
                      detector,
                      pose_estimator,
                      visualizer=None,
                      show_interval=0):
    """Visualize predicted keypoints (and heatmaps) of one image."""

    # predict bbox
    det_result = inference_detector(detector, img)
    pred_instance = det_result.pred_instances.cpu().numpy()
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
                                   pred_instance.scores > args.bbox_thr)]
    bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

    # predict keypoints
    pose_results = inference_topdown(pose_estimator, img, bboxes)
    data_samples = merge_data_samples(pose_results)

    # show the results
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=data_samples,
            draw_gt=False,
            draw_heatmap=args.draw_heatmap,
            draw_bbox=args.draw_bbox,
            show_kpt_idx=args.show_kpt_idx,
            skeleton_style=args.skeleton_style,
            show=args.show,
            wait_time=show_interval,
            kpt_thr=args.kpt_thr)

    # if there is no instance detected, return None
    return data_samples.get('pred_instances', None)


def main():
    """Visualize the demo images.

    Use mmdet to detect human bounding boxes, then run top-down pose
    estimation on each detected instance.
    """
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument(
        '--input', type=str, default='', help='Image/Video file')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show the visualized image')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='root of the output img file. '
        'Default not saving the visualization images.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='whether to save predicted results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='Category id for bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=0.3,
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Keypoint score threshold for visualization')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        default=False,
        help='Draw heatmap predicted by the model')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--draw-bbox', action='store_true', help='Draw bboxes of instances')

    assert has_mmdet, 'Please install mmdet to run the demo.'

    args = parser.parse_args()

    assert args.show or (args.output_root != '')
    assert args.input != ''
    assert args.det_config is not None
    assert args.det_checkpoint is not None

    output_file = None
    if args.output_root:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'

    if args.save_predictions:
        assert args.output_root != ''
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    # build detector
    detector = init_detector(
        args.det_config, args.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # build pose estimator
    pose_estimator = init_pose_estimator(
        args.pose_config,
        args.pose_checkpoint,
        device=args.device,
        cfg_options=dict(
            model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))

    # build visualizer
    pose_estimator.cfg.visualizer.radius = args.radius
    pose_estimator.cfg.visualizer.alpha = args.alpha
    pose_estimator.cfg.visualizer.line_width = args.thickness
    visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_pose_estimator
    visualizer.set_dataset_meta(
        pose_estimator.dataset_meta, skeleton_style=args.skeleton_style)

    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if input_type == 'image':

        # inference
        pred_instances = process_one_image(args, args.input, detector,
                                           pose_estimator, visualizer)

        if args.save_predictions:
            pred_instances_list = split_instances(pred_instances)

        if output_file:
            img_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:

        if args.input == 'webcam':
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(args.input)

        video_writer = None
        pred_instances_list = []
        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break

            # topdown pose estimation
            pred_instances = process_one_image(args, frame, detector,
                                               pose_estimator, visualizer,
                                               0.001)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_instances)))

            # output videos
            if output_file:
                frame_vis = visualizer.get_image()

                if video_writer is None:
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(
                        output_file,
                        fourcc,
                        25,  # saved fps
                        (frame_vis.shape[1], frame_vis.shape[0]))

                video_writer.write(mmcv.rgb2bgr(frame_vis))

            # press ESC to exit
            if cv2.waitKey(5) & 0xFF == 27:
                break

            time.sleep(args.show_interval)

        if video_writer:
            video_writer.release()

        cap.release()

    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=pose_estimator.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print(f'predictions have been saved at {args.pred_save_path}')


if __name__ == '__main__':
    main()
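
# Example invocation (a sketch only; every path below is a placeholder and
# must be replaced with real MMDetection / MMPose config files and
# checkpoints from the respective model zoos):
#
#   python <this_script>.py \
#       DET_CONFIG.py DET_CHECKPOINT.pth \
#       POSE_CONFIG.py POSE_CHECKPOINT.pth \
#       --input demo.jpg --output-root vis_results/ \
#       --draw-heatmap --draw-bbox --save-predictions
#
# Passing ``--input webcam`` reads frames from the default camera instead of
# a file; with ``--output-root`` set, the stream is saved as an .mp4 file.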