openpose_visualization.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
import mimetypes
import os
from argparse import ArgumentParser
from itertools import product

import cv2
import mmcv
import numpy as np
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.structures import merge_data_samples

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

# openpose format
limb_seq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
            [1, 16], [16, 18]]

colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
          [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
          [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
          [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
          [255, 0, 170], [255, 0, 85]]

stickwidth = 4
num_openpose_kpt = 18
num_link = len(limb_seq)
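
# The 18-keypoint OpenPose (BODY_18) order assumed throughout this script:
# 0 nose, 1 neck, 2-4 right shoulder/elbow/wrist, 5-7 left shoulder/elbow/
# wrist, 8-10 right hip/knee/ankle, 11-13 left hip/knee/ankle, 14/15 right/
# left eye, 16/17 right/left ear. Entries in `limb_seq` are 1-indexed and
# are converted to 0-indexed when drawing.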


def mmpose_to_openpose_visualization(args, img_path, detector, pose_estimator):
    """Visualize predicted keypoints of one image in openpose format."""
    # predict bbox
    init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
    det_result = inference_detector(detector, img_path)
    pred_instance = det_result.pred_instances.cpu().numpy()
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
                                   pred_instance.scores > args.bbox_thr)]
    bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

    # predict keypoints
    pose_results = inference_topdown(pose_estimator, img_path, bboxes)
    data_samples = merge_data_samples(pose_results)

    # concatenate scores and keypoints
    keypoints = np.concatenate(
        (data_samples.pred_instances.keypoints,
         data_samples.pred_instances.keypoint_scores.reshape(-1, 17, 1)),
        axis=-1)
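
    # OpenPose includes a neck joint that COCO lacks; it is synthesized below
    # as the midpoint of the two shoulders (COCO indices 5 and 6).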
    # compute neck joint
    neck = (keypoints[:, 5] + keypoints[:, 6]) / 2
    # suppress the neck score per instance when either shoulder is below the
    # threshold (an elementwise comparison cannot be used in a plain `if`;
    # with more than one detected person it would raise a ValueError)
    low_conf = np.logical_or(keypoints[:, 5, 2] < args.kpt_thr,
                             keypoints[:, 6, 2] < args.kpt_thr)
    neck[low_conf, 2] = 0

    # 17 keypoints to 18 keypoints
    new_keypoints = np.insert(keypoints, 17, neck, axis=1)

    # mmpose format to openpose format
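    # (fancy indexing on the right-hand side returns a copy, so the in-place
    # reorder below does not clobber values before they are read)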
    openpose_idx = [15, 14, 17, 16, 2, 6, 3, 7, 4, 8, 12, 9, 13, 10, 1]
    mmpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
    new_keypoints[:, openpose_idx, :] = new_keypoints[:, mmpose_idx, :]

    # show the results
    img = mmcv.imread(img_path, channel_order='rgb')

    # black background
    black_img = np.zeros_like(img)

    num_instance = new_keypoints.shape[0]

    # draw keypoints
    for i, j in product(range(num_instance), range(num_openpose_kpt)):
        x, y, conf = new_keypoints[i][j]
        if conf > args.kpt_thr:
            cv2.circle(black_img, (int(x), int(y)), 4, colors[j], thickness=-1)

    # draw links
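    # Each limb is rendered as a filled ellipse rotated to the segment angle,
    # following the original OpenPose drawing convention. Note that `Y` below
    # holds x-coordinates and `X` holds y-coordinates, mirroring that code.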
    cur_black_img = black_img.copy()
    for i, link_idx in product(range(num_instance), range(num_link)):
        conf = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 2]
        if np.sum(conf > args.kpt_thr) == 2:
            Y = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 0]
            X = new_keypoints[i][np.array(limb_seq[link_idx]) - 1, 1]
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly(
                (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle),
                0, 360, 1)
            cv2.fillConvexPoly(cur_black_img, polygon, colors[link_idx])
    black_img = cv2.addWeighted(black_img, 0.4, cur_black_img, 0.6, 0)

    # save image (the image was read as RGB, so swap channels back to BGR
    # before handing it to cv2.imwrite)
    out_file = 'openpose_' + os.path.splitext(
        os.path.basename(img_path))[0] + '.png'
    cv2.imwrite(out_file, black_img[:, :, [2, 1, 0]])


def main():
    """Visualize the demo images.

    Using mmdet to detect the human.
    """
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument('--input', type=str, help='Input image file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='Category id for bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.4,
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=0.3,
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.4, help='Keypoint score threshold')

    assert has_mmdet, 'Please install mmdet to run the demo.'

    args = parser.parse_args()

    assert args.input != ''
    assert args.det_config is not None
    assert args.det_checkpoint is not None
    # build detector
    detector = init_detector(
        args.det_config, args.det_checkpoint, device=args.device)

    # build pose estimator
    pose_estimator = init_pose_estimator(
        args.pose_config,
        args.pose_checkpoint,
        device=args.device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))))

    # guess_type returns (None, None) for unrecognized extensions, so guard
    # against None before splitting the MIME type
    mime_type = mimetypes.guess_type(args.input)[0]
    if mime_type is not None and mime_type.split('/')[0] == 'image':
        mmpose_to_openpose_visualization(args, args.input, detector,
                                         pose_estimator)


if __name__ == '__main__':
    main()
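
# A minimal usage sketch (the config and checkpoint paths are placeholders,
# not files shipped with this script; substitute real MMDetection and MMPose
# configs/weights):
#
#   python openpose_visualization.py \
#       det_config.py det_checkpoint.pth \
#       pose_config.py pose_checkpoint.pth \
#       --input demo.jpg --device cuda:0
#
# The visualization is saved to the working directory as
# 'openpose_<input basename>.png'.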