123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os.path as osp
- from functools import wraps
- import mmengine
- import numpy as np
- from PIL import Image
- from mmpose.utils import SimpleCamera
- def _keypoint_camera_to_world(keypoints,
- camera_params,
- image_name=None,
- dataset='Body3DH36MDataset'):
- """Project 3D keypoints from the camera space to the world space.
- Args:
- keypoints (np.ndarray): 3D keypoints in shape [..., 3]
- camera_params (dict): Parameters for all cameras.
- image_name (str): The image name to specify the camera.
- dataset (str): The dataset type, e.g., Body3DH36MDataset.
- """
- cam_key = None
- if dataset == 'Body3DH36MDataset':
- subj, rest = osp.basename(image_name).split('_', 1)
- _, rest = rest.split('.', 1)
- camera, rest = rest.split('_', 1)
- cam_key = (subj, camera)
- else:
- raise NotImplementedError
- camera = SimpleCamera(camera_params[cam_key])
- keypoints_world = keypoints.copy()
- keypoints_world[..., :3] = camera.camera_to_world(keypoints[..., :3])
- return keypoints_world
- def _get_bbox_xywh(center, scale, w=200, h=200):
- w = w * scale
- h = h * scale
- x = center[0] - w / 2
- y = center[1] - h / 2
- return [x, y, w, h]
- def mmcv_track_func(func):
- @wraps(func)
- def wrapped_func(args):
- return func(*args)
- return wrapped_func
- @mmcv_track_func
- def _get_img_info(img_idx, img_name, img_root):
- try:
- im = Image.open(osp.join(img_root, img_name))
- w, h = im.size
- except: # noqa: E722
- return None
- img = {
- 'file_name': img_name,
- 'height': h,
- 'width': w,
- 'id': img_idx + 1,
- }
- return img
- @mmcv_track_func
- def _get_ann(idx, kpt_2d, kpt_3d, center, scale, imgname, camera_params):
- bbox = _get_bbox_xywh(center, scale)
- kpt_3d = _keypoint_camera_to_world(kpt_3d, camera_params, imgname)
- ann = {
- 'id': idx + 1,
- 'category_id': 1,
- 'image_id': idx + 1,
- 'iscrowd': 0,
- 'bbox': bbox,
- 'area': bbox[2] * bbox[3],
- 'num_keypoints': 17,
- 'keypoints': kpt_2d.reshape(-1).tolist(),
- 'keypoints_3d': kpt_3d.reshape(-1).tolist()
- }
- return ann
- def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--ann-file', type=str, default='tests/data/h36m/test_h36m_body3d.npz')
- parser.add_argument(
- '--camera-param-file', type=str, default='tests/data/h36m/cameras.pkl')
- parser.add_argument('--img-root', type=str, default='tests/data/h36m')
- parser.add_argument(
- '--out-file', type=str, default='tests/data/h36m/h36m_coco.json')
- parser.add_argument('--full-img-name', action='store_true')
- args = parser.parse_args()
- h36m_data = np.load(args.ann_file)
- h36m_camera_params = mmengine.load(args.camera_param_file)
- h36m_coco = {}
- # categories
- h36m_cats = [{
- 'supercategory':
- 'person',
- 'id':
- 1,
- 'name':
- 'person',
- 'keypoints': [
- 'root (pelvis)', 'left_hip', 'left_knee', 'left_foot', 'right_hip',
- 'right_knee', 'right_foot', 'spine', 'thorax', 'neck_base', 'head',
- 'left_shoulder', 'left_elbow', 'left_wrist', 'right_shoulder',
- 'right_elbow', 'right_wrist'
- ],
- 'skeleton': [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7],
- [7, 8], [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
- [8, 14], [14, 15], [15, 16]],
- }]
- # images
- imgnames = h36m_data['imgname']
- if not args.full_img_name:
- imgnames = [osp.basename(fn) for fn in imgnames]
- tasks = [(idx, fn, args.img_root) for idx, fn in enumerate(imgnames)]
- h36m_imgs = mmengine.track_parallel_progress(
- _get_img_info, tasks, nproc=12)
- # annotations
- kpts_2d = h36m_data['part']
- kpts_3d = h36m_data['S']
- centers = h36m_data['center']
- scales = h36m_data['scale']
- tasks = [(idx, ) + args + (h36m_camera_params, )
- for idx, args in enumerate(
- zip(kpts_2d, kpts_3d, centers, scales, imgnames))]
- h36m_anns = mmengine.track_parallel_progress(_get_ann, tasks, nproc=12)
- # remove invalid data
- h36m_imgs = [img for img in h36m_imgs if img is not None]
- h36m_img_ids = set([img['id'] for img in h36m_imgs])
- h36m_anns = [ann for ann in h36m_anns if ann['image_id'] in h36m_img_ids]
- h36m_coco = {
- 'categories': h36m_cats,
- 'images': h36m_imgs,
- 'annotations': h36m_anns,
- }
- mmengine.dump(h36m_coco, args.out_file)
- if __name__ == '__main__':
- main()
|