# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time
from typing import List, Tuple

import cv2
import loguru
import numpy as np
import onnxruntime as ort

logger = loguru.logger


def parse_args():
    parser = argparse.ArgumentParser(
        description='RTMPose ONNX inference demo.')
    parser.add_argument('onnx_file', help='ONNX file path')
    parser.add_argument('image_file', help='Input image file path')
    parser.add_argument(
        '--device', help='device type for inference', default='cpu')
    parser.add_argument(
        '--save-path',
        help='path to save the output image',
        default='output.jpg')
    args = parser.parse_args()
    return args


def preprocess(
    img: np.ndarray, input_size: Tuple[int, int] = (192, 256)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Do preprocessing for RTMPose model inference.

    Args:
        img (np.ndarray): Input image in shape (H, W, C).
        input_size (tuple): Input size of the model in shape (w, h).

    Returns:
        tuple:
        - resized_img (np.ndarray): Preprocessed image.
        - center (np.ndarray): Center of the bbox in shape (2,).
        - scale (np.ndarray): Scale of the bbox in shape (2,).
    """
    # get shape of image and treat the whole image as the bbox
    img_shape = img.shape[:2]
    bbox = np.array([0, 0, img_shape[1], img_shape[0]])

    # get center and scale
    center, scale = bbox_xyxy2cs(bbox, padding=1.25)

    # do affine transformation
    resized_img, scale = top_down_affine(input_size, scale, center, img)

    # normalize image
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])
    resized_img = (resized_img - mean) / std

    return resized_img, center, scale
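
# Example (illustrative, with a synthetic 640x480 image):
# >>> dummy = np.zeros((480, 640, 3), dtype=np.uint8)
# >>> resized_img, center, scale = preprocess(dummy, (192, 256))
# >>> resized_img.shape   # (256, 192, 3) -- (h, w, c) after the affine warp
# >>> center              # array([320., 240.]), the image center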


def build_session(onnx_file: str, device: str = 'cpu') -> ort.InferenceSession:
    """Build onnxruntime session.

    Args:
        onnx_file (str): ONNX file path.
        device (str): Device type for inference.

    Returns:
        sess (ort.InferenceSession): ONNXRuntime session.
    """
    providers = ['CPUExecutionProvider'
                 ] if device == 'cpu' else ['CUDAExecutionProvider']
    sess = ort.InferenceSession(path_or_bytes=onnx_file, providers=providers)
    return sess
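
# Usage sketch (the file name is illustrative):
# >>> sess = build_session('rtmpose-m.onnx', device='cpu')
# Any device value other than 'cpu' selects CUDAExecutionProvider, which
# requires a GPU-enabled onnxruntime build.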


def inference(sess: ort.InferenceSession,
              img: np.ndarray) -> List[np.ndarray]:
    """Inference RTMPose model.

    Args:
        sess (ort.InferenceSession): ONNXRuntime session.
        img (np.ndarray): Preprocessed image in shape (H, W, C).

    Returns:
        outputs (List[np.ndarray]): Outputs of the RTMPose model.
    """
    # build input: HWC -> CHW, batched; cast to float32, the dtype most
    # exported models expect
    model_input = [img.transpose(2, 0, 1).astype(np.float32)]

    # build input dict and collect output names
    sess_input = {sess.get_inputs()[0].name: model_input}
    sess_output = []
    for out in sess.get_outputs():
        sess_output.append(out.name)

    # run model
    outputs = sess.run(sess_output, sess_input)
    return outputs
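
# Note (a sketch of the expected outputs): an RTMPose SimCC head returns two
# arrays, simcc_x in shape (1, K, w * simcc_split_ratio) and simcc_y in shape
# (1, K, h * simcc_split_ratio); exact names and shapes depend on how the
# model was exported.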


def postprocess(outputs: List[np.ndarray],
                model_input_size: Tuple[int, int],
                center: np.ndarray,
                scale: np.ndarray,
                simcc_split_ratio: float = 2.0
                ) -> Tuple[np.ndarray, np.ndarray]:
    """Postprocess for RTMPose model output.

    Args:
        outputs (List[np.ndarray]): Outputs of the RTMPose model.
        model_input_size (tuple): Input size of the model in shape (w, h).
        center (np.ndarray): Center of the bbox in shape (x, y).
        scale (np.ndarray): Scale of the bbox in shape (w, h).
        simcc_split_ratio (float): Split ratio of simcc.

    Returns:
        tuple:
        - keypoints (np.ndarray): Rescaled keypoints.
        - scores (np.ndarray): Model predicted scores.
    """
    # use simcc to decode
    simcc_x, simcc_y = outputs
    keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)

    # rescale keypoints from model-input coordinates back to the original image
    keypoints = keypoints / model_input_size * scale + center - scale / 2

    return keypoints, scores
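
# Worked example of the rescale line (illustrative numbers): with
# model_input_size = (192, 256), scale = (800, 1066.67) and
# center = (320, 240), a keypoint predicted at (96, 128) -- the center of
# the model input -- maps back to
#   (96, 128) / (192, 256) * (800, 1066.67) + (320, 240) - (800, 1066.67) / 2
#   = (400, 533.33) + (320, 240) - (400, 533.33) = (320, 240),
# i.e. the center of the original bbox, as expected.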


def visualize(img: np.ndarray,
              keypoints: np.ndarray,
              scores: np.ndarray,
              filename: str = 'output.jpg',
              thr: float = 0.3) -> np.ndarray:
    """Visualize the keypoints and skeleton on image.

    Args:
        img (np.ndarray): Input image in shape (H, W, C).
        keypoints (np.ndarray): Keypoints in image.
        scores (np.ndarray): Model predicted scores.
        filename (str): Path to save the visualized image.
        thr (float): Score threshold for drawing skeleton links.

    Returns:
        img (np.ndarray): Visualized image.
    """
    # default colors; indices follow the 133-keypoint COCO-WholeBody layout
    skeleton = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
                (6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
                (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (15, 17),
                (15, 18), (15, 19), (16, 20), (16, 21), (16, 22), (91, 92),
                (92, 93), (93, 94), (94, 95), (91, 96), (96, 97), (97, 98),
                (98, 99), (91, 100), (100, 101), (101, 102), (102, 103),
                (91, 104), (104, 105), (105, 106), (106, 107), (91, 108),
                (108, 109), (109, 110), (110, 111), (112, 113), (113, 114),
                (114, 115), (115, 116), (112, 117), (117, 118), (118, 119),
                (119, 120), (112, 121), (121, 122), (122, 123), (123, 124),
                (112, 125), (125, 126), (126, 127), (127, 128), (112, 129),
                (129, 130), (130, 131), (131, 132)]
    palette = [[51, 153, 255], [0, 255, 0], [255, 128, 0], [255, 255, 255],
               [255, 153, 255], [102, 178, 255], [255, 51, 51]]
    link_color = [
        1, 1, 2, 2, 0, 0, 0, 0, 1, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1, 2, 2, 2,
        2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
    ]
    point_color = [
        0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2,
        4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1, 3, 2, 2, 2, 2, 4, 4, 4,
        4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
    ]

    # draw keypoints and skeleton
    for kpts, score in zip(keypoints, scores):
        for kpt, color in zip(kpts, point_color):
            cv2.circle(img, tuple(kpt.astype(np.int32)), 1, palette[color], 1,
                       cv2.LINE_AA)
        for (u, v), color in zip(skeleton, link_color):
            if score[u] > thr and score[v] > thr:
                cv2.line(img, tuple(kpts[u].astype(np.int32)),
                         tuple(kpts[v].astype(np.int32)), palette[color], 2,
                         cv2.LINE_AA)

    # save to local
    cv2.imwrite(filename, img)
    return img


def bbox_xyxy2cs(bbox: np.ndarray,
                 padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
    """Transform the bbox format from (x1, y1, x2, y2) into (center, scale).

    Args:
        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
            as (left, top, right, bottom).
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
            (n, 2)
        - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
            (n, 2)
    """
    # convert a single bbox from (4,) to (1, 4)
    dim = bbox.ndim
    if dim == 1:
        bbox = bbox[None, :]

    # get bbox center and scale
    x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
    center = np.hstack([x1 + x2, y1 + y2]) * 0.5
    scale = np.hstack([x2 - x1, y2 - y1]) * padding

    if dim == 1:
        center = center[0]
        scale = scale[0]

    return center, scale
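
# Example (illustrative): the full-image bbox of a 640x480 image with the
# same padding preprocess() uses:
# >>> bbox_xyxy2cs(np.array([0, 0, 640, 480]), padding=1.25)
# (array([320., 240.]), array([800., 600.]))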


def _fix_aspect_ratio(bbox_scale: np.ndarray,
                      aspect_ratio: float) -> np.ndarray:
    """Extend the scale to match the given aspect ratio.

    Args:
        bbox_scale (np.ndarray): The bbox scale (w, h) in shape (2,).
        aspect_ratio (float): The ratio of ``w/h``.

    Returns:
        np.ndarray: The reshaped bbox scale in shape (2,).
    """
    w, h = np.hsplit(bbox_scale, [1])
    bbox_scale = np.where(w > h * aspect_ratio,
                          np.hstack([w, w / aspect_ratio]),
                          np.hstack([h * aspect_ratio, h]))
    return bbox_scale
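
# Example (illustrative): widening the (800, 600) scale from above to the
# 192x256 model aspect ratio (w/h = 0.75):
# >>> _fix_aspect_ratio(np.array([800., 600.]), aspect_ratio=0.75)
# array([ 800.        , 1066.66666667])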


def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
    """Rotate a point by an angle.

    Args:
        pt (np.ndarray): 2D point coordinates (x, y) in shape (2,).
        angle_rad (float): Rotation angle in radians.

    Returns:
        np.ndarray: Rotated point in shape (2,).
    """
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    rot_mat = np.array([[cs, -sn], [sn, cs]])
    return rot_mat @ pt
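
# Example: rotating the unit x-vector by 90 degrees anticlockwise:
# >>> _rotate_point(np.array([1., 0.]), np.pi / 2)
# array([0., 1.])   # up to floating-point error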


def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """To calculate the affine matrix, three pairs of points are required.
    This function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating the vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): The 1st point (x, y) in shape (2,).
        b (np.ndarray): The 2nd point (x, y) in shape (2,).

    Returns:
        np.ndarray: The 3rd point.
    """
    direction = a - b
    c = b + np.r_[-direction[1], direction[0]]
    return c
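
# Example: with a = (1, 0) and b = (0, 0), rotating a - b by 90 degrees
# anticlockwise about b gives the third point (0, 1):
# >>> _get_3rd_point(np.array([1., 0.]), np.array([0., 0.]))
# array([0., 1.])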


def get_warp_matrix(center: np.ndarray,
                    scale: np.ndarray,
                    rot: float,
                    output_size: Tuple[int, int],
                    shift: Tuple[float, float] = (0., 0.),
                    inv: bool = False) -> np.ndarray:
    """Calculate the affine transformation matrix that can warp the bbox area
    in the input image to the output size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ] | list(2,)): Size of the
            destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: A 2x3 transformation matrix.
    """
    shift = np.array(shift)
    src_w = scale[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    # compute transformation matrix
    rot_rad = np.deg2rad(rot)
    src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    # get three points of the src rectangle in the original image
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale * shift
    src[1, :] = center + src_dir + scale * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    # get three points of the dst rectangle in the output image
    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return warp_mat
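
# Sketch: cv2.getAffineTransform solves the 2x3 matrix from exactly three
# point pairs, so src/dst each hold the bbox center, a point offset by half
# the bbox width (rotated by `rot`), and a third point from _get_3rd_point
# that pins down the remaining degree of freedom.
# >>> m = get_warp_matrix(np.array([320., 240.]),
# ...                     np.array([800., 1066.67]), 0., (192, 256))
# >>> m.shape   # (2, 3)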


def top_down_affine(input_size: Tuple[int, int], bbox_scale: np.ndarray,
                    bbox_center: np.ndarray,
                    img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get the bbox image as the model input by affine transform.

    Args:
        input_size (tuple): The input size (w, h) of the model.
        bbox_scale (np.ndarray): The bbox scale of the img.
        bbox_center (np.ndarray): The bbox center of the img.
        img (np.ndarray): The original image.

    Returns:
        tuple: A tuple containing the warped image and the bbox scale.
        - np.ndarray[float32]: img after affine transform.
        - np.ndarray[float32]: bbox scale after aspect-ratio fixing.
    """
    w, h = input_size
    warp_size = (int(w), int(h))

    # reshape bbox to a fixed aspect ratio
    bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)

    # get the affine matrix
    center = bbox_center
    scale = bbox_scale
    rot = 0
    warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))

    # do affine transform
    img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)

    return img, bbox_scale
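
# Example (illustrative, continuing the 640x480 case above):
# >>> warped, scale = top_down_affine((192, 256), np.array([800., 600.]),
# ...                                 np.array([320., 240.]), dummy)
# >>> warped.shape   # (256, 192, 3)
# >>> scale          # array([ 800.        , 1066.66666667])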


def get_simcc_maximum(simcc_x: np.ndarray,
                      simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get maximum response location and value from simcc representations.

    Note:
        instance number: N
        num_keypoints: K
        heatmap height: H
        heatmap width: W

    Args:
        simcc_x (np.ndarray): x-axis SimCC in shape (N, K, Wx)
        simcc_y (np.ndarray): y-axis SimCC in shape (N, K, Wy)

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (N, K, 2)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (N, K)
    """
    N, K, Wx = simcc_x.shape
    simcc_x = simcc_x.reshape(N * K, -1)
    simcc_y = simcc_y.reshape(N * K, -1)

    # get maximum value locations
    x_locs = np.argmax(simcc_x, axis=1)
    y_locs = np.argmax(simcc_y, axis=1)
    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
    max_val_x = np.amax(simcc_x, axis=1)
    max_val_y = np.amax(simcc_y, axis=1)

    # the keypoint score is the smaller of the x- and y-axis peak values
    mask = max_val_x > max_val_y
    max_val_x[mask] = max_val_y[mask]
    vals = max_val_x
    locs[vals <= 0.] = -1

    # reshape
    locs = locs.reshape(N, K, 2)
    vals = vals.reshape(N, K)

    return locs, vals
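
# Toy example (illustrative): one instance, one keypoint, four bins per axis:
# >>> sx = np.array([[[0.1, 0.9, 0.2, 0.0]]])   # x peak at bin 1
# >>> sy = np.array([[[0.0, 0.1, 0.8, 0.3]]])   # y peak at bin 2
# >>> get_simcc_maximum(sx, sy)
# (array([[[1., 2.]]], dtype=float32), array([[0.8]]))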


def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
           simcc_split_ratio: float) -> Tuple[np.ndarray, np.ndarray]:
    """Decode keypoint coordinates and scores from simcc representations.

    Args:
        simcc_x (np.ndarray[N, K, Wx]): model predicted simcc in x.
        simcc_y (np.ndarray[N, K, Wy]): model predicted simcc in y.
        simcc_split_ratio (float): The split ratio of simcc.

    Returns:
        tuple: A tuple containing keypoints and scores.
        - np.ndarray[float32]: keypoints in shape (N, K, 2)
        - np.ndarray[float32]: scores in shape (N, K)
    """
    keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
    keypoints /= simcc_split_ratio
    return keypoints, scores
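
# Continuing the toy example above: with simcc_split_ratio = 2.0, the peak
# at bins (1, 2) decodes to (0.5, 1.0) in model-input pixel coordinates.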


def main():
    args = parse_args()
    logger.info('Start running RTMPose model...')

    # read image from file
    logger.info('1. Read image from {}...'.format(args.image_file))
    img = cv2.imread(args.image_file)

    # build onnx model
    logger.info('2. Build ONNX model from {}...'.format(args.onnx_file))
    sess = build_session(args.onnx_file, args.device)
    h, w = sess.get_inputs()[0].shape[2:]
    model_input_size = (w, h)

    # preprocessing
    logger.info('3. Preprocess image...')
    resized_img, center, scale = preprocess(img, model_input_size)

    # inference
    logger.info('4. Inference...')
    start_time = time.time()
    outputs = inference(sess, resized_img)
    end_time = time.time()
    logger.info('4. Inference done, time cost: {:.4f}s'.format(end_time -
                                                               start_time))

    # postprocessing
    logger.info('5. Postprocess...')
    keypoints, scores = postprocess(outputs, model_input_size, center, scale)

    # visualize inference result
    logger.info('6. Visualize inference result...')
    visualize(img, keypoints, scores, args.save_path)

    logger.info('Done...')


if __name__ == '__main__':
    main()
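
# Typical invocation (paths are illustrative):
#   python main.py rtmpose-m.onnx demo.jpg --device cpu --save-path output.jpg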