h36m_to_coco.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from functools import wraps
  5. import mmengine
  6. import numpy as np
  7. from PIL import Image
  8. from mmpose.utils import SimpleCamera
  9. def _keypoint_camera_to_world(keypoints,
  10. camera_params,
  11. image_name=None,
  12. dataset='Body3DH36MDataset'):
  13. """Project 3D keypoints from the camera space to the world space.
  14. Args:
  15. keypoints (np.ndarray): 3D keypoints in shape [..., 3]
  16. camera_params (dict): Parameters for all cameras.
  17. image_name (str): The image name to specify the camera.
  18. dataset (str): The dataset type, e.g., Body3DH36MDataset.
  19. """
  20. cam_key = None
  21. if dataset == 'Body3DH36MDataset':
  22. subj, rest = osp.basename(image_name).split('_', 1)
  23. _, rest = rest.split('.', 1)
  24. camera, rest = rest.split('_', 1)
  25. cam_key = (subj, camera)
  26. else:
  27. raise NotImplementedError
  28. camera = SimpleCamera(camera_params[cam_key])
  29. keypoints_world = keypoints.copy()
  30. keypoints_world[..., :3] = camera.camera_to_world(keypoints[..., :3])
  31. return keypoints_world
  32. def _get_bbox_xywh(center, scale, w=200, h=200):
  33. w = w * scale
  34. h = h * scale
  35. x = center[0] - w / 2
  36. y = center[1] - h / 2
  37. return [x, y, w, h]
  38. def mmcv_track_func(func):
  39. @wraps(func)
  40. def wrapped_func(args):
  41. return func(*args)
  42. return wrapped_func
  43. @mmcv_track_func
  44. def _get_img_info(img_idx, img_name, img_root):
  45. try:
  46. im = Image.open(osp.join(img_root, img_name))
  47. w, h = im.size
  48. except: # noqa: E722
  49. return None
  50. img = {
  51. 'file_name': img_name,
  52. 'height': h,
  53. 'width': w,
  54. 'id': img_idx + 1,
  55. }
  56. return img
  57. @mmcv_track_func
  58. def _get_ann(idx, kpt_2d, kpt_3d, center, scale, imgname, camera_params):
  59. bbox = _get_bbox_xywh(center, scale)
  60. kpt_3d = _keypoint_camera_to_world(kpt_3d, camera_params, imgname)
  61. ann = {
  62. 'id': idx + 1,
  63. 'category_id': 1,
  64. 'image_id': idx + 1,
  65. 'iscrowd': 0,
  66. 'bbox': bbox,
  67. 'area': bbox[2] * bbox[3],
  68. 'num_keypoints': 17,
  69. 'keypoints': kpt_2d.reshape(-1).tolist(),
  70. 'keypoints_3d': kpt_3d.reshape(-1).tolist()
  71. }
  72. return ann
  73. def main():
  74. parser = argparse.ArgumentParser()
  75. parser.add_argument(
  76. '--ann-file', type=str, default='tests/data/h36m/test_h36m_body3d.npz')
  77. parser.add_argument(
  78. '--camera-param-file', type=str, default='tests/data/h36m/cameras.pkl')
  79. parser.add_argument('--img-root', type=str, default='tests/data/h36m')
  80. parser.add_argument(
  81. '--out-file', type=str, default='tests/data/h36m/h36m_coco.json')
  82. parser.add_argument('--full-img-name', action='store_true')
  83. args = parser.parse_args()
  84. h36m_data = np.load(args.ann_file)
  85. h36m_camera_params = mmengine.load(args.camera_param_file)
  86. h36m_coco = {}
  87. # categories
  88. h36m_cats = [{
  89. 'supercategory':
  90. 'person',
  91. 'id':
  92. 1,
  93. 'name':
  94. 'person',
  95. 'keypoints': [
  96. 'root (pelvis)', 'left_hip', 'left_knee', 'left_foot', 'right_hip',
  97. 'right_knee', 'right_foot', 'spine', 'thorax', 'neck_base', 'head',
  98. 'left_shoulder', 'left_elbow', 'left_wrist', 'right_shoulder',
  99. 'right_elbow', 'right_wrist'
  100. ],
  101. 'skeleton': [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7],
  102. [7, 8], [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
  103. [8, 14], [14, 15], [15, 16]],
  104. }]
  105. # images
  106. imgnames = h36m_data['imgname']
  107. if not args.full_img_name:
  108. imgnames = [osp.basename(fn) for fn in imgnames]
  109. tasks = [(idx, fn, args.img_root) for idx, fn in enumerate(imgnames)]
  110. h36m_imgs = mmengine.track_parallel_progress(
  111. _get_img_info, tasks, nproc=12)
  112. # annotations
  113. kpts_2d = h36m_data['part']
  114. kpts_3d = h36m_data['S']
  115. centers = h36m_data['center']
  116. scales = h36m_data['scale']
  117. tasks = [(idx, ) + args + (h36m_camera_params, )
  118. for idx, args in enumerate(
  119. zip(kpts_2d, kpts_3d, centers, scales, imgnames))]
  120. h36m_anns = mmengine.track_parallel_progress(_get_ann, tasks, nproc=12)
  121. # remove invalid data
  122. h36m_imgs = [img for img in h36m_imgs if img is not None]
  123. h36m_img_ids = set([img['id'] for img in h36m_imgs])
  124. h36m_anns = [ann for ann in h36m_anns if ann['image_id'] in h36m_img_ids]
  125. h36m_coco = {
  126. 'categories': h36m_cats,
  127. 'images': h36m_imgs,
  128. 'annotations': h36m_anns,
  129. }
  130. mmengine.dump(h36m_coco, args.out_file)
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()