topdown_demo_with_mmdet.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mimetypes
  3. import os
  4. import time
  5. from argparse import ArgumentParser
  6. import cv2
  7. import json_tricks as json
  8. import mmcv
  9. import mmengine
  10. import numpy as np
  11. from mmpose.apis import inference_topdown
  12. from mmpose.apis import init_model as init_pose_estimator
  13. from mmpose.evaluation.functional import nms
  14. from mmpose.registry import VISUALIZERS
  15. from mmpose.structures import merge_data_samples, split_instances
  16. from mmpose.utils import adapt_mmdet_pipeline
  17. try:
  18. from mmdet.apis import inference_detector, init_detector
  19. has_mmdet = True
  20. except (ImportError, ModuleNotFoundError):
  21. has_mmdet = False
  22. def process_one_image(args,
  23. img,
  24. detector,
  25. pose_estimator,
  26. visualizer=None,
  27. show_interval=0):
  28. """Visualize predicted keypoints (and heatmaps) of one image."""
  29. # predict bbox
  30. det_result = inference_detector(detector, img)
  31. pred_instance = det_result.pred_instances.cpu().numpy()
  32. bboxes = np.concatenate(
  33. (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
  34. bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
  35. pred_instance.scores > args.bbox_thr)]
  36. bboxes = bboxes[nms(bboxes, args.nms_thr), :4]
  37. # predict keypoints
  38. pose_results = inference_topdown(pose_estimator, img, bboxes)
  39. data_samples = merge_data_samples(pose_results)
  40. # show the results
  41. if isinstance(img, str):
  42. img = mmcv.imread(img, channel_order='rgb')
  43. elif isinstance(img, np.ndarray):
  44. img = mmcv.bgr2rgb(img)
  45. if visualizer is not None:
  46. visualizer.add_datasample(
  47. 'result',
  48. img,
  49. data_sample=data_samples,
  50. draw_gt=False,
  51. draw_heatmap=args.draw_heatmap,
  52. draw_bbox=args.draw_bbox,
  53. show_kpt_idx=args.show_kpt_idx,
  54. skeleton_style=args.skeleton_style,
  55. show=args.show,
  56. wait_time=show_interval,
  57. kpt_thr=args.kpt_thr)
  58. # if there is no instance detected, return None
  59. return data_samples.get('pred_instances', None)
  60. def main():
  61. """Visualize the demo images.
  62. Using mmdet to detect the human.
  63. """
  64. parser = ArgumentParser()
  65. parser.add_argument('det_config', help='Config file for detection')
  66. parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
  67. parser.add_argument('pose_config', help='Config file for pose')
  68. parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
  69. parser.add_argument(
  70. '--input', type=str, default='', help='Image/Video file')
  71. parser.add_argument(
  72. '--show',
  73. action='store_true',
  74. default=False,
  75. help='whether to show img')
  76. parser.add_argument(
  77. '--output-root',
  78. type=str,
  79. default='',
  80. help='root of the output img file. '
  81. 'Default not saving the visualization images.')
  82. parser.add_argument(
  83. '--save-predictions',
  84. action='store_true',
  85. default=False,
  86. help='whether to save predicted results')
  87. parser.add_argument(
  88. '--device', default='cuda:0', help='Device used for inference')
  89. parser.add_argument(
  90. '--det-cat-id',
  91. type=int,
  92. default=0,
  93. help='Category id for bounding box detection model')
  94. parser.add_argument(
  95. '--bbox-thr',
  96. type=float,
  97. default=0.3,
  98. help='Bounding box score threshold')
  99. parser.add_argument(
  100. '--nms-thr',
  101. type=float,
  102. default=0.3,
  103. help='IoU threshold for bounding box NMS')
  104. parser.add_argument(
  105. '--kpt-thr',
  106. type=float,
  107. default=0.3,
  108. help='Visualizing keypoint thresholds')
  109. parser.add_argument(
  110. '--draw-heatmap',
  111. action='store_true',
  112. default=False,
  113. help='Draw heatmap predicted by the model')
  114. parser.add_argument(
  115. '--show-kpt-idx',
  116. action='store_true',
  117. default=False,
  118. help='Whether to show the index of keypoints')
  119. parser.add_argument(
  120. '--skeleton-style',
  121. default='mmpose',
  122. type=str,
  123. choices=['mmpose', 'openpose'],
  124. help='Skeleton style selection')
  125. parser.add_argument(
  126. '--radius',
  127. type=int,
  128. default=3,
  129. help='Keypoint radius for visualization')
  130. parser.add_argument(
  131. '--thickness',
  132. type=int,
  133. default=1,
  134. help='Link thickness for visualization')
  135. parser.add_argument(
  136. '--show-interval', type=int, default=0, help='Sleep seconds per frame')
  137. parser.add_argument(
  138. '--alpha', type=float, default=0.8, help='The transparency of bboxes')
  139. parser.add_argument(
  140. '--draw-bbox', action='store_true', help='Draw bboxes of instances')
  141. assert has_mmdet, 'Please install mmdet to run the demo.'
  142. args = parser.parse_args()
  143. assert args.show or (args.output_root != '')
  144. assert args.input != ''
  145. assert args.det_config is not None
  146. assert args.det_checkpoint is not None
  147. output_file = None
  148. if args.output_root:
  149. mmengine.mkdir_or_exist(args.output_root)
  150. output_file = os.path.join(args.output_root,
  151. os.path.basename(args.input))
  152. if args.input == 'webcam':
  153. output_file += '.mp4'
  154. if args.save_predictions:
  155. assert args.output_root != ''
  156. args.pred_save_path = f'{args.output_root}/results_' \
  157. f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
  158. # build detector
  159. detector = init_detector(
  160. args.det_config, args.det_checkpoint, device=args.device)
  161. detector.cfg = adapt_mmdet_pipeline(detector.cfg)
  162. # build pose estimator
  163. pose_estimator = init_pose_estimator(
  164. args.pose_config,
  165. args.pose_checkpoint,
  166. device=args.device,
  167. cfg_options=dict(
  168. model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))
  169. # build visualizer
  170. pose_estimator.cfg.visualizer.radius = args.radius
  171. pose_estimator.cfg.visualizer.alpha = args.alpha
  172. pose_estimator.cfg.visualizer.line_width = args.thickness
  173. visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
  174. # the dataset_meta is loaded from the checkpoint and
  175. # then pass to the model in init_pose_estimator
  176. visualizer.set_dataset_meta(
  177. pose_estimator.dataset_meta, skeleton_style=args.skeleton_style)
  178. if args.input == 'webcam':
  179. input_type = 'webcam'
  180. else:
  181. input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
  182. if input_type == 'image':
  183. # inference
  184. pred_instances = process_one_image(args, args.input, detector,
  185. pose_estimator, visualizer)
  186. if args.save_predictions:
  187. pred_instances_list = split_instances(pred_instances)
  188. if output_file:
  189. img_vis = visualizer.get_image()
  190. mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)
  191. elif input_type in ['webcam', 'video']:
  192. if args.input == 'webcam':
  193. cap = cv2.VideoCapture(0)
  194. else:
  195. cap = cv2.VideoCapture(args.input)
  196. video_writer = None
  197. pred_instances_list = []
  198. frame_idx = 0
  199. while cap.isOpened():
  200. success, frame = cap.read()
  201. frame_idx += 1
  202. if not success:
  203. break
  204. # topdown pose estimation
  205. pred_instances = process_one_image(args, frame, detector,
  206. pose_estimator, visualizer,
  207. 0.001)
  208. if args.save_predictions:
  209. # save prediction results
  210. pred_instances_list.append(
  211. dict(
  212. frame_id=frame_idx,
  213. instances=split_instances(pred_instances)))
  214. # output videos
  215. if output_file:
  216. frame_vis = visualizer.get_image()
  217. if video_writer is None:
  218. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  219. # the size of the image with visualization may vary
  220. # depending on the presence of heatmaps
  221. video_writer = cv2.VideoWriter(
  222. output_file,
  223. fourcc,
  224. 25, # saved fps
  225. (frame_vis.shape[1], frame_vis.shape[0]))
  226. video_writer.write(mmcv.rgb2bgr(frame_vis))
  227. # press ESC to exit
  228. if cv2.waitKey(5) & 0xFF == 27:
  229. break
  230. time.sleep(args.show_interval)
  231. if video_writer:
  232. video_writer.release()
  233. cap.release()
  234. else:
  235. args.save_predictions = False
  236. raise ValueError(
  237. f'file {os.path.basename(args.input)} has invalid format.')
  238. if args.save_predictions:
  239. with open(args.pred_save_path, 'w') as f:
  240. json.dump(
  241. dict(
  242. meta_info=pose_estimator.dataset_meta,
  243. instance_info=pred_instances_list),
  244. f,
  245. indent='\t')
  246. print(f'predictions have been saved at {args.pred_save_path}')
  247. if __name__ == '__main__':
  248. main()