# Copyright (c) OpenMMLab. All rights reserved.
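"""Preprocess the MPI-INF-3DHP dataset into the format used by mmpose.

The script extracts frames from the training videos, copies the test
images, and writes annotation files (``*.npz``), camera parameter files
(``*.pkl``) and pose statistics under ``out_dir``.

Usage:
    python preprocess_mpi_inf_3dhp.py ${DATA_ROOT} ${OUT_DIR}
"""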
import argparse
import os
import pickle
import shutil
from os.path import join

import cv2
import h5py
import mmcv
import numpy as np
from scipy.io import loadmat

train_subjects = [i for i in range(1, 9)]
test_subjects = [i for i in range(1, 7)]
train_seqs = [1, 2]
train_cams = [0, 1, 2, 4, 5, 6, 7, 8]
train_frame_nums = {
    (1, 1): 6416,
    (1, 2): 12430,
    (2, 1): 6502,
    (2, 2): 6081,
    (3, 1): 12488,
    (3, 2): 12283,
    (4, 1): 6171,
    (4, 2): 6675,
    (5, 1): 12820,
    (5, 2): 12312,
    (6, 1): 6188,
    (6, 2): 6145,
    (7, 1): 6239,
    (7, 2): 6320,
    (8, 1): 6468,
    (8, 2): 6054
}
test_frame_nums = {1: 6151, 2: 6080, 3: 5838, 4: 6007, 5: 320, 6: 492}
train_img_size = (2048, 2048)
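# `joints_17` selects, from the 28-joint training annotation, the 17-joint
# subset used by mmpose; `root_index` is the index of the root (pelvis)
# joint within that subset.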
root_index = 14
joints_17 = [7, 5, 14, 15, 16, 9, 10, 11, 23, 24, 25, 18, 19, 20, 4, 3, 6]


def get_pose_stats(kps):
    """Get the mean and std of pose data.

    Args:
        kps (ndarray): keypoints in shape [..., K, D] where K and D are
            the keypoint number and dimension.

    Returns:
        mean (ndarray): [K, D]
        std (ndarray): [K, D]
    """
    assert kps.ndim > 2
    K, D = kps.shape[-2:]
    kps = kps.reshape(-1, K, D)
    mean = kps.mean(axis=0)
    std = kps.std(axis=0)
    return mean, std
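
# Shape sanity check for `get_pose_stats` (illustrative only):
#   kps = np.zeros((100, 17, 3))      # e.g. 100 frames of 17 3D joints
#   mean, std = get_pose_stats(kps)   # both returned with shape (17, 3)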


def get_annotations(joints_2d, joints_3d, scale_factor=1.2):
    """Get annotations, including centers, scales, joints_2d and joints_3d.

    Args:
        joints_2d: 2D joint coordinates in shape [N, K, 2], where N is the
            frame number, K is the joint number.
        joints_3d: 3D joint coordinates in shape [N, K, 3], where N is the
            frame number, K is the joint number.
        scale_factor: Scale factor of bounding box. Default: 1.2.

    Returns:
        centers (ndarray): [N, 2]
        scales (ndarray): [N,]
        joints_2d (ndarray): [N, K, 3]
        joints_3d (ndarray): [N, K, 4]
    """
    # calculate joint visibility
    visibility = (joints_2d[:, :, 0] >= 0) * \
                 (joints_2d[:, :, 0] < train_img_size[0]) * \
                 (joints_2d[:, :, 1] >= 0) * \
                 (joints_2d[:, :, 1] < train_img_size[1])
    visibility = np.array(visibility, dtype=np.float32)[:, :, None]
    joints_2d = np.concatenate([joints_2d, visibility], axis=-1)
    joints_3d = np.concatenate([joints_3d, visibility], axis=-1)

    # calculate bounding boxes
    bboxes = np.stack([
        np.min(joints_2d[:, :, 0], axis=1),
        np.min(joints_2d[:, :, 1], axis=1),
        np.max(joints_2d[:, :, 0], axis=1),
        np.max(joints_2d[:, :, 1], axis=1)
    ],
                      axis=1)
    centers = np.stack([(bboxes[:, 0] + bboxes[:, 2]) / 2,
                        (bboxes[:, 1] + bboxes[:, 3]) / 2],
                       axis=1)
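    # By the convention used across mmpose datasets, `scale` is expressed
    # in units of 200 pixels: scale * 200 equals the larger bbox side,
    # padded by `scale_factor`.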
    scales = scale_factor * np.max(bboxes[:, 2:] - bboxes[:, :2], axis=1) / 200
    return centers, scales, joints_2d, joints_3d


def load_trainset(data_root, out_dir):
    """Load training data, create annotation file and camera file.

    Args:
        data_root: Directory of dataset, which is organized in the following
            hierarchy:
                data_root
                |-- train
                    |-- S1
                        |-- Seq1
                        |-- Seq2
                    |-- S2
                    |-- ...
                |-- test
                    |-- TS1
                    |-- TS2
                    |-- ...
        out_dir: Directory to save annotation file.
    """
    _imgnames = []
    _centers = []
    _scales = []
    _joints_2d = []
    _joints_3d = []
    cameras = {}

    img_dir = join(out_dir, 'images')
    os.makedirs(img_dir, exist_ok=True)
    annot_dir = join(out_dir, 'annotations')
    os.makedirs(annot_dir, exist_ok=True)

    for subj in train_subjects:
        for seq in train_seqs:
            seq_path = join(data_root, 'train', f'S{subj}', f'Seq{seq}')
            num_frames = train_frame_nums[(subj, seq)]

            # load camera parameters
            camera_file = join(seq_path, 'camera.calibration')
            with open(camera_file, 'r') as fin:
                lines = fin.readlines()
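            # Each camera block in `camera.calibration` spans 7 lines.
            # The `intrinsic` entry (offset 5) is a flattened 4x4 matrix
            # with fx, cx, fy, cy at indices 0, 2, 5 and 6; the
            # `extrinsic` entry (offset 6) is a flattened 4x4 [R|T]
            # matrix. The `[11:-2]` slice strips the field name and the
            # trailing characters before the values are parsed.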
            for cam in train_cams:
                K = [float(s) for s in lines[cam * 7 + 5][11:-2].split()]
                f = np.array([[K[0]], [K[5]]])
                c = np.array([[K[2]], [K[6]]])
                RT = np.array(
                    [float(s) for s in lines[cam * 7 + 6][11:-2].split()])
                RT = np.reshape(RT, (4, 4))
                R = RT[:3, :3]
                # convert unit from millimeter to meter
                T = RT[:3, 3:] * 0.001
                size = [int(s) for s in lines[cam * 7 + 3][14:].split()]
                w, h = size
                cam_param = dict(
                    R=R, T=T, c=c, f=f, w=w, h=h, name=f'train_cam_{cam}')
                cameras[f'S{subj}_Seq{seq}_Cam{cam}'] = cam_param

            # load annotations (read the .mat file once for both keys)
            annot_file = os.path.join(seq_path, 'annot.mat')
            annot = loadmat(annot_file)
            annot2 = annot['annot2']
            annot3 = annot['annot3']
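
            # `annot2` / `annot3` are per-camera cell arrays; each entry
            # holds all frames flattened to (num_frames, 28 * dim). Below,
            # the 28 annotated joints are reduced to the 17-joint subset
            # and 3D coordinates are converted from millimetres to metres.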
            for cam in train_cams:
                # load 2D and 3D annotations
                joints_2d = np.reshape(annot2[cam][0][:num_frames],
                                       (num_frames, 28, 2))[:, joints_17]
                joints_3d = np.reshape(annot3[cam][0][:num_frames],
                                       (num_frames, 28, 3))[:, joints_17]
                joints_3d = joints_3d * 0.001
                centers, scales, joints_2d, joints_3d = get_annotations(
                    joints_2d, joints_3d)
                _centers.append(centers)
                _scales.append(scales)
                _joints_2d.append(joints_2d)
                _joints_3d.append(joints_3d)

                # extract frames from video
                video_path = join(seq_path, 'imageSequence',
                                  f'video_{cam}.avi')
                video = mmcv.VideoReader(video_path)
                for i in mmcv.track_iter_progress(range(num_frames)):
                    img = video.read()
                    if img is None:
                        break
                    imgname = f'S{subj}_Seq{seq}_Cam{cam}_{i+1:06d}.jpg'
                    _imgnames.append(imgname)
                    cv2.imwrite(join(img_dir, imgname), img)

    _imgnames = np.array(_imgnames)
    _centers = np.concatenate(_centers)
    _scales = np.concatenate(_scales)
    _joints_2d = np.concatenate(_joints_2d)
    _joints_3d = np.concatenate(_joints_3d)
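
    # The npz keys follow the convention of mmpose's 3D body datasets:
    # 'part' holds 2D keypoints and 'S' holds 3D joints (both with a
    # visibility flag appended by `get_annotations`).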
    out_file = join(annot_dir, 'mpi_inf_3dhp_train.npz')
    np.savez(
        out_file,
        imgname=_imgnames,
        center=_centers,
        scale=_scales,
        part=_joints_2d,
        S=_joints_3d)
    print(f'Create annotation file for trainset: {out_file}. '
          f'{len(_imgnames)} samples in total.')

    out_file = join(annot_dir, 'cameras_train.pkl')
    with open(out_file, 'wb') as fout:
        pickle.dump(cameras, fout)
    print(f'Create camera file for trainset: {out_file}.')

    # get `mean` and `std` of pose data
    _joints_3d = _joints_3d[..., :3]  # remove visibility
    mean_3d, std_3d = get_pose_stats(_joints_3d)

    _joints_2d = _joints_2d[..., :2]  # remove visibility
    mean_2d, std_2d = get_pose_stats(_joints_2d)

    # centered around root
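    # (the root joint keeps its absolute statistics, since its
    # root-relative mean and std would otherwise collapse to zero)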
    _joints_3d_rel = _joints_3d - _joints_3d[..., root_index:root_index + 1, :]
    mean_3d_rel, std_3d_rel = get_pose_stats(_joints_3d_rel)
    mean_3d_rel[root_index] = mean_3d[root_index]
    std_3d_rel[root_index] = std_3d[root_index]

    _joints_2d_rel = _joints_2d - _joints_2d[..., root_index:root_index + 1, :]
    mean_2d_rel, std_2d_rel = get_pose_stats(_joints_2d_rel)
    mean_2d_rel[root_index] = mean_2d[root_index]
    std_2d_rel[root_index] = std_2d[root_index]
    stats = {
        'joint3d_stats': {
            'mean': mean_3d,
            'std': std_3d
        },
        'joint2d_stats': {
            'mean': mean_2d,
            'std': std_2d
        },
        'joint3d_rel_stats': {
            'mean': mean_3d_rel,
            'std': std_3d_rel
        },
        'joint2d_rel_stats': {
            'mean': mean_2d_rel,
            'std': std_2d_rel
        }
    }
    for name, stat_dict in stats.items():
        out_file = join(annot_dir, f'{name}.pkl')
        with open(out_file, 'wb') as f:
            pickle.dump(stat_dict, f)
        print(f'Create statistic data file: {out_file}')


def load_testset(data_root, out_dir, valid_only=True):
    """Load testing data, create annotation file and camera file.

    Args:
        data_root: Directory of dataset.
        out_dir: Directory to save annotation file.
        valid_only: Only keep frames whose `valid_frame` flag is 1.
    """
    _imgnames = []
    _centers = []
    _scales = []
    _joints_2d = []
    _joints_3d = []
    cameras = {}

    img_dir = join(out_dir, 'images')
    os.makedirs(img_dir, exist_ok=True)
    annot_dir = join(out_dir, 'annotations')
    os.makedirs(annot_dir, exist_ok=True)
    for subj in test_subjects:
        subj_path = join(data_root, 'test', f'TS{subj}')
        num_frames = test_frame_nums[subj]

        # load annotations
        annot_file = os.path.join(subj_path, 'annot_data.mat')
        with h5py.File(annot_file, 'r') as fin:
            annot2 = np.array(fin['annot2']).reshape((-1, 17, 2))
            annot3 = np.array(fin['annot3']).reshape((-1, 17, 3))
            valid = np.array(fin['valid_frame']).reshape(-1)

        # manually estimate camera intrinsics
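        # Under a pinhole model, u = fx * X / Z + cx, hence
        # u * Z = fx * X + cx * Z. Stacking all joints yields the linear
        # system [X, Z] @ [fx, cx]^T = u * Z, which is solved below with
        # np.linalg.lstsq (and analogously fy, cy from the y-coordinates).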
        fx, cx = np.linalg.lstsq(
            annot3[:, :, [0, 2]].reshape((-1, 2)),
            (annot2[:, :, 0] * annot3[:, :, 2]).reshape(-1, 1),
            rcond=None)[0].flatten()
        fy, cy = np.linalg.lstsq(
            annot3[:, :, [1, 2]].reshape((-1, 2)),
            (annot2[:, :, 1] * annot3[:, :, 2]).reshape(-1, 1),
            rcond=None)[0].flatten()
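        # TS1-TS4 are indoor (studio) sequences recorded at 2048x2048;
        # TS5 and TS6 are outdoor sequences at 1920x1080.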
        if subj <= 4:
            w, h = 2048, 2048
        else:
            w, h = 1920, 1080
        cameras[f'TS{subj}'] = dict(
            c=np.array([[cx], [cy]]),
            f=np.array([[fx], [fy]]),
            w=w,
            h=h,
            name=f'test_cam_{subj}')

        # get annotations
        if valid_only:
            valid_frames = np.nonzero(valid)[0]
        else:
            valid_frames = np.arange(num_frames)
        joints_2d = annot2[valid_frames, :, :]
        joints_3d = annot3[valid_frames, :, :] * 0.001

        centers, scales, joints_2d, joints_3d = get_annotations(
            joints_2d, joints_3d)
        _centers.append(centers)
        _scales.append(scales)
        _joints_2d.append(joints_2d)
        _joints_3d.append(joints_3d)

        # copy and rename images
        for i in valid_frames:
            imgname = f'TS{subj}_{i+1:06d}.jpg'
            shutil.copyfile(
                join(subj_path, 'imageSequence', f'img_{i+1:06d}.jpg'),
                join(img_dir, imgname))
            _imgnames.append(imgname)

    _imgnames = np.array(_imgnames)
    _centers = np.concatenate(_centers)
    _scales = np.concatenate(_scales)
    _joints_2d = np.concatenate(_joints_2d)
    _joints_3d = np.concatenate(_joints_3d)
    if valid_only:
        out_file = join(annot_dir, 'mpi_inf_3dhp_test_valid.npz')
    else:
        out_file = join(annot_dir, 'mpi_inf_3dhp_test_all.npz')
    np.savez(
        out_file,
        imgname=_imgnames,
        center=_centers,
        scale=_scales,
        part=_joints_2d,
        S=_joints_3d)
    print(f'Create annotation file for testset: {out_file}. '
          f'{len(_imgnames)} samples in total.')

    out_file = join(annot_dir, 'cameras_test.pkl')
    with open(out_file, 'wb') as fout:
        pickle.dump(cameras, fout)
    print(f'Create camera file for testset: {out_file}.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_root', type=str, help='data root')
    parser.add_argument(
        'out_dir', type=str, help='directory to save annotation files.')
    args = parser.parse_args()

    data_root = args.data_root
    out_dir = args.out_dir

    load_trainset(data_root, out_dir)
    load_testset(data_root, out_dir, valid_only=True)