bottomup_demo.py

# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import time
from argparse import ArgumentParser

import cv2
import json_tricks as json  # serializes numpy arrays/scalars in the saved predictions
import mmcv
import mmengine
import numpy as np

from mmpose.apis import inference_bottomup, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import split_instances


def process_one_image(args,
                      img,
                      pose_estimator,
                      visualizer=None,
                      show_interval=0):
    """Visualize predicted keypoints (and heatmaps) of one image."""

    # inference a single image
    batch_results = inference_bottomup(pose_estimator, img)
    results = batch_results[0]

    # show the results
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=results,
            draw_gt=False,
            draw_bbox=False,
            draw_heatmap=args.draw_heatmap,
            show_kpt_idx=args.show_kpt_idx,
            show=args.show,
            wait_time=show_interval,
            kpt_thr=args.kpt_thr)

    return results.pred_instances
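

# A minimal sketch of driving ``process_one_image`` programmatically, outside
# the CLI defined below. The config/checkpoint paths and the ``opts`` namespace
# are placeholders, not files shipped with this demo:
#
#   from argparse import Namespace
#   model = init_model('some_bottomup_config.py', 'some_checkpoint.pth',
#                      device='cpu')
#   opts = Namespace(draw_heatmap=False, show_kpt_idx=False, show=False,
#                    kpt_thr=0.3)
#   pred = process_one_image(opts, 'demo.jpg', model, visualizer=None)
#   print(pred.keypoints.shape)  # typically (num_instances, num_keypoints, 2)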


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--input', type=str, default='', help='Image/Video file')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show the visualized image')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='root directory of the output image/video file. '
        'By default the visualization is not saved.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='whether to save the predicted results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Visualize the predicted heatmap')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    args = parser.parse_args()
    return args
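

# Example invocations (sketch; CONFIG.py and CHECKPOINT.pth are placeholders
# for whichever bottom-up pose config and checkpoint you actually use):
#
#   python bottomup_demo.py CONFIG.py CHECKPOINT.pth \
#       --input demo.jpg --output-root vis_results --draw-heatmap
#
#   python bottomup_demo.py CONFIG.py CHECKPOINT.pth \
#       --input webcam --show --save-predictions --output-root vis_results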


def main():
    args = parse_args()
    assert args.show or (args.output_root != '')
    assert args.input != ''

    output_file = None
    if args.output_root:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'

    if args.save_predictions:
        assert args.output_root != ''
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    # build the model from a config file and a checkpoint file
    if args.draw_heatmap:
        cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))
    else:
        cfg_options = None

    model = init_model(
        args.config,
        args.checkpoint,
        device=args.device,
        cfg_options=cfg_options)

    # build visualizer
    model.cfg.visualizer.radius = args.radius
    model.cfg.visualizer.line_width = args.thickness
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.set_dataset_meta(model.dataset_meta)

    # determine the input type: the special keyword 'webcam', or the MIME
    # type ('image' / 'video') guessed from the file name
    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if input_type == 'image':
        # inference
        pred_instances = process_one_image(
            args, args.input, model, visualizer, show_interval=0)

        if args.save_predictions:
            pred_instances_list = split_instances(pred_instances)

        if output_file:
            img_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

    elif input_type in ['webcam', 'video']:

        if args.input == 'webcam':
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(args.input)

        video_writer = None
        pred_instances_list = []
        frame_idx = 0

        while cap.isOpened():
            success, frame = cap.read()
            frame_idx += 1

            if not success:
                break

            pred_instances = process_one_image(args, frame, model, visualizer,
                                               0.001)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_instances)))

            # output videos
            if output_file:
                frame_vis = visualizer.get_image()

                if video_writer is None:
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(
                        output_file,
                        fourcc,
                        25,  # saved fps
                        (frame_vis.shape[1], frame_vis.shape[0]))

                video_writer.write(mmcv.rgb2bgr(frame_vis))

            # press ESC to exit
            if cv2.waitKey(5) & 0xFF == 27:
                break

            time.sleep(args.show_interval)

        if video_writer:
            video_writer.release()

        cap.release()

    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=model.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print(f'predictions have been saved at {args.pred_save_path}')


if __name__ == '__main__':
    main()
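

# Sketch of reloading the saved predictions for downstream use. Assumes a
# previous run with ``--save-predictions --output-root vis_results --input
# demo.jpg`` wrote ``vis_results/results_demo.json``:
#
#   import json_tricks as json
#   with open('vis_results/results_demo.json') as f:
#       data = json.load(f)
#   print(data['meta_info'].get('dataset_name'))
#   # per-instance entries for an image input, per-frame entries for a video
#   print(len(data['instance_info']))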