preprocess_h36m.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # -----------------------------------------------------------------------------
  2. # Adapted from https://github.com/anibali/h36m-fetch
  3. # Original license: Copyright (c) Aiden Nibali, under the Apache License.
  4. # -----------------------------------------------------------------------------
  5. import argparse
  6. import os
  7. import pickle
  8. import tarfile
  9. import xml.etree.ElementTree as ET
  10. from os.path import join
  11. import cv2
  12. import numpy as np
  13. from spacepy import pycdf
  14. class PreprocessH36m:
  15. """Preprocess Human3.6M dataset.
  16. Args:
  17. metadata (str): Path to metadata.xml.
  18. original_dir (str): Directory of the original dataset with all files
  19. compressed. Specifically, .tgz files belonging to subject 1
  20. should be placed under the subdirectory 's1'.
  21. extracted_dir (str): Directory of the extracted files. If not given, it
  22. will be placed under the same parent directory as original_dir.
  23. processed_der (str): Directory of the processed files. If not given, it
  24. will be placed under the same parent directory as original_dir.
  25. sample_rate (int): Downsample FPS to `1 / sample_rate`. Default: 5.
  26. """
  27. def __init__(self,
  28. metadata,
  29. original_dir,
  30. extracted_dir=None,
  31. processed_dir=None,
  32. sample_rate=5):
  33. self.metadata = metadata
  34. self.original_dir = original_dir
  35. self.sample_rate = sample_rate
  36. if extracted_dir is None:
  37. self.extracted_dir = join(
  38. os.path.dirname(os.path.abspath(self.original_dir)),
  39. 'extracted')
  40. else:
  41. self.extracted_dir = extracted_dir
  42. if processed_dir is None:
  43. self.processed_dir = join(
  44. os.path.dirname(os.path.abspath(self.original_dir)),
  45. 'processed')
  46. else:
  47. self.processed_dir = processed_dir
  48. self.subjects = []
  49. self.sequence_mappings = {}
  50. self.action_names = {}
  51. self.camera_ids = []
  52. self._load_metadata()
  53. self.subjects_annot = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11']
  54. self.subjects_splits = {
  55. 'train': ['S1', 'S5', 'S6', 'S7', 'S8'],
  56. 'test': ['S9', 'S11']
  57. }
  58. self.extract_files = ['Videos', 'D2_Positions', 'D3_Positions_mono']
  59. self.movable_joints = [
  60. 0, 1, 2, 3, 6, 7, 8, 12, 13, 14, 15, 17, 18, 19, 25, 26, 27
  61. ]
  62. self.scale_factor = 1.2
  63. self.image_sizes = {
  64. '54138969': {
  65. 'width': 1000,
  66. 'height': 1002
  67. },
  68. '55011271': {
  69. 'width': 1000,
  70. 'height': 1000
  71. },
  72. '58860488': {
  73. 'width': 1000,
  74. 'height': 1000
  75. },
  76. '60457274': {
  77. 'width': 1000,
  78. 'height': 1002
  79. }
  80. }
  81. def extract_tgz(self):
  82. """Extract files from self.extrct_files."""
  83. os.makedirs(self.extracted_dir, exist_ok=True)
  84. for subject in self.subjects_annot:
  85. cur_dir = join(self.original_dir, subject.lower())
  86. for file in self.extract_files:
  87. filename = join(cur_dir, file + '.tgz')
  88. print(f'Extracting {filename} ...')
  89. with tarfile.open(filename) as tar:
  90. tar.extractall(self.extracted_dir)
  91. print('Extraction done.\n')
  92. def generate_cameras_file(self):
  93. """Generate cameras.pkl which contains camera parameters for 11
  94. subjects each with 4 cameras."""
  95. cameras = {}
  96. for subject in range(1, 12):
  97. for camera in range(4):
  98. key = (f'S{subject}', self.camera_ids[camera])
  99. cameras[key] = self._get_camera_params(camera, subject)
  100. out_file = join(self.processed_dir, 'annotation_body3d', 'cameras.pkl')
  101. with open(out_file, 'wb') as fout:
  102. pickle.dump(cameras, fout)
  103. print(f'Camera parameters have been written to "{out_file}".\n')
  104. def generate_annotations(self):
  105. """Generate annotations for training and testing data."""
  106. output_dir = join(self.processed_dir, 'annotation_body3d',
  107. f'fps{50 // self.sample_rate}')
  108. os.makedirs(output_dir, exist_ok=True)
  109. for data_split in ('train', 'test'):
  110. imgnames_all = []
  111. centers_all = []
  112. scales_all = []
  113. kps2d_all = []
  114. kps3d_all = []
  115. for subject in self.subjects_splits[data_split]:
  116. for action, subaction in self.sequence_mappings[subject].keys(
  117. ):
  118. if action == '1':
  119. # exclude action "_ALL"
  120. continue
  121. for camera in self.camera_ids:
  122. imgnames, centers, scales, kps2d, kps3d\
  123. = self._load_annotations(
  124. subject, action, subaction, camera)
  125. imgnames_all.append(imgnames)
  126. centers_all.append(centers)
  127. scales_all.append(scales)
  128. kps2d_all.append(kps2d)
  129. kps3d_all.append(kps3d)
  130. imgnames_all = np.concatenate(imgnames_all)
  131. centers_all = np.concatenate(centers_all)
  132. scales_all = np.concatenate(scales_all)
  133. kps2d_all = np.concatenate(kps2d_all)
  134. kps3d_all = np.concatenate(kps3d_all)
  135. out_file = join(output_dir, f'h36m_{data_split}.npz')
  136. np.savez(
  137. out_file,
  138. imgname=imgnames_all,
  139. center=centers_all,
  140. scale=scales_all,
  141. part=kps2d_all,
  142. S=kps3d_all)
  143. print(
  144. f'All annotations of {data_split}ing data have been written to'
  145. f' "{out_file}". {len(imgnames_all)} samples in total.\n')
  146. if data_split == 'train':
  147. kps_3d_all = kps3d_all[..., :3] # remove visibility
  148. mean_3d, std_3d = self._get_pose_stats(kps_3d_all)
  149. kps_2d_all = kps2d_all[..., :2] # remove visibility
  150. mean_2d, std_2d = self._get_pose_stats(kps_2d_all)
  151. # centered around root
  152. # the root keypoint is 0-index
  153. kps_3d_rel = kps_3d_all[..., 1:, :] - kps_3d_all[..., :1, :]
  154. mean_3d_rel, std_3d_rel = self._get_pose_stats(kps_3d_rel)
  155. kps_2d_rel = kps_2d_all[..., 1:, :] - kps_2d_all[..., :1, :]
  156. mean_2d_rel, std_2d_rel = self._get_pose_stats(kps_2d_rel)
  157. stats = {
  158. 'joint3d_stats': {
  159. 'mean': mean_3d,
  160. 'std': std_3d
  161. },
  162. 'joint2d_stats': {
  163. 'mean': mean_2d,
  164. 'std': std_2d
  165. },
  166. 'joint3d_rel_stats': {
  167. 'mean': mean_3d_rel,
  168. 'std': std_3d_rel
  169. },
  170. 'joint2d_rel_stats': {
  171. 'mean': mean_2d_rel,
  172. 'std': std_2d_rel
  173. }
  174. }
  175. for name, stat_dict in stats.items():
  176. out_file = join(output_dir, f'{name}.pkl')
  177. with open(out_file, 'wb') as f:
  178. pickle.dump(stat_dict, f)
  179. print(f'Create statistic data file: {out_file}')
  180. @staticmethod
  181. def _get_pose_stats(kps):
  182. """Get statistic information `mean` and `std` of pose data.
  183. Args:
  184. kps (ndarray): keypoints in shape [..., K, D] where K and C is
  185. the keypoint category number and dimension.
  186. Returns:
  187. mean (ndarray): [K, D]
  188. """
  189. assert kps.ndim > 2
  190. K, D = kps.shape[-2:]
  191. kps = kps.reshape(-1, K, D)
  192. mean = kps.mean(axis=0)
  193. std = kps.std(axis=0)
  194. return mean, std
  195. def _load_metadata(self):
  196. """Load meta data from metadata.xml."""
  197. assert os.path.exists(self.metadata)
  198. tree = ET.parse(self.metadata)
  199. root = tree.getroot()
  200. for i, tr in enumerate(root.find('mapping')):
  201. if i == 0:
  202. _, _, *self.subjects = [td.text for td in tr]
  203. self.sequence_mappings \
  204. = {subject: {} for subject in self.subjects}
  205. elif i < 33:
  206. action_id, subaction_id, *prefixes = [td.text for td in tr]
  207. for subject, prefix in zip(self.subjects, prefixes):
  208. self.sequence_mappings[subject][(action_id, subaction_id)]\
  209. = prefix
  210. for i, elem in enumerate(root.find('actionnames')):
  211. action_id = str(i + 1)
  212. self.action_names[action_id] = elem.text
  213. self.camera_ids \
  214. = [elem.text for elem in root.find('dbcameras/index2id')]
  215. w0 = root.find('w0')
  216. self.cameras_raw = [float(num) for num in w0.text[1:-1].split()]
  217. def _get_base_filename(self, subject, action, subaction, camera):
  218. """Get base filename given subject, action, subaction and camera."""
  219. return f'{self.sequence_mappings[subject][(action, subaction)]}' + \
  220. f'.{camera}'
  221. def _get_camera_params(self, camera, subject):
  222. """Get camera parameters given camera id and subject id."""
  223. metadata_slice = np.zeros(15)
  224. start = 6 * (camera * 11 + (subject - 1))
  225. metadata_slice[:6] = self.cameras_raw[start:start + 6]
  226. metadata_slice[6:] = self.cameras_raw[265 + camera * 9 - 1:265 +
  227. (camera + 1) * 9 - 1]
  228. # extrinsics
  229. x, y, z = -metadata_slice[0], metadata_slice[1], -metadata_slice[2]
  230. R_x = np.array([[1, 0, 0], [0, np.cos(x), np.sin(x)],
  231. [0, -np.sin(x), np.cos(x)]])
  232. R_y = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0],
  233. [-np.sin(y), 0, np.cos(y)]])
  234. R_z = np.array([[np.cos(z), np.sin(z), 0], [-np.sin(z),
  235. np.cos(z), 0], [0, 0, 1]])
  236. R = (R_x @ R_y @ R_z).T
  237. T = metadata_slice[3:6].reshape(-1, 1)
  238. # convert unit from millimeter to meter
  239. T *= 0.001
  240. # intrinsics
  241. c = metadata_slice[8:10, None]
  242. f = metadata_slice[6:8, None]
  243. # distortion
  244. k = metadata_slice[10:13, None]
  245. p = metadata_slice[13:15, None]
  246. return {
  247. 'R': R,
  248. 'T': T,
  249. 'c': c,
  250. 'f': f,
  251. 'k': k,
  252. 'p': p,
  253. 'w': self.image_sizes[self.camera_ids[camera]]['width'],
  254. 'h': self.image_sizes[self.camera_ids[camera]]['height'],
  255. 'name': f'camera{camera + 1}',
  256. 'id': self.camera_ids[camera]
  257. }
  258. def _load_annotations(self, subject, action, subaction, camera):
  259. """Load annotations for a sequence."""
  260. subj_dir = join(self.extracted_dir, subject)
  261. basename = self._get_base_filename(subject, action, subaction, camera)
  262. # load 2D keypoints
  263. with pycdf.CDF(
  264. join(subj_dir, 'MyPoseFeatures', 'D2_Positions',
  265. basename + '.cdf')) as cdf:
  266. kps_2d = np.array(cdf['Pose'])
  267. num_frames = kps_2d.shape[1]
  268. kps_2d = kps_2d.reshape((num_frames, 32, 2))[::self.sample_rate,
  269. self.movable_joints]
  270. kps_2d = np.concatenate([kps_2d, np.ones((len(kps_2d), 17, 1))],
  271. axis=2)
  272. # load 3D keypoints
  273. with pycdf.CDF(
  274. join(subj_dir, 'MyPoseFeatures', 'D3_Positions_mono',
  275. basename + '.cdf')) as cdf:
  276. kps_3d = np.array(cdf['Pose'])
  277. kps_3d = kps_3d.reshape(
  278. (num_frames, 32, 3))[::self.sample_rate,
  279. self.movable_joints] / 1000.
  280. kps_3d = np.concatenate([kps_3d, np.ones((len(kps_3d), 17, 1))],
  281. axis=2)
  282. # calculate bounding boxes
  283. bboxes = np.stack([
  284. np.min(kps_2d[:, :, 0], axis=1),
  285. np.min(kps_2d[:, :, 1], axis=1),
  286. np.max(kps_2d[:, :, 0], axis=1),
  287. np.max(kps_2d[:, :, 1], axis=1)
  288. ],
  289. axis=1)
  290. centers = np.stack([(bboxes[:, 0] + bboxes[:, 2]) / 2,
  291. (bboxes[:, 1] + bboxes[:, 3]) / 2],
  292. axis=1)
  293. scales = self.scale_factor * np.max(
  294. bboxes[:, 2:] - bboxes[:, :2], axis=1) / 200
  295. # extract frames and save imgnames
  296. imgnames = []
  297. video_path = join(subj_dir, 'Videos', basename + '.mp4')
  298. sub_base = subject + '_' + basename.replace(' ', '_')
  299. img_dir = join(self.processed_dir, 'images', subject, sub_base)
  300. os.makedirs(img_dir, exist_ok=True)
  301. prefix = join(subject, sub_base, sub_base)
  302. cap = cv2.VideoCapture(video_path)
  303. i = 0
  304. while True:
  305. success, img = cap.read()
  306. if not success:
  307. break
  308. if i % self.sample_rate == 0:
  309. imgname = f'{prefix}_{i + 1:06d}.jpg'
  310. imgnames.append(imgname)
  311. dest_path = join(self.processed_dir, 'images', imgname)
  312. if not os.path.exists(dest_path):
  313. cv2.imwrite(dest_path, img)
  314. if len(imgnames) == len(centers):
  315. break
  316. i += 1
  317. cap.release()
  318. imgnames = np.array(imgnames)
  319. print(f'Annoatations for sequence "{subject} {basename}" are loaded. '
  320. f'{len(imgnames)} samples in total.')
  321. return imgnames, centers, scales, kps_2d, kps_3d
  322. def parse_args():
  323. parser = argparse.ArgumentParser()
  324. parser.add_argument(
  325. '--metadata', type=str, required=True, help='Path to metadata.xml')
  326. parser.add_argument(
  327. '--original',
  328. type=str,
  329. required=True,
  330. help='Directory of the original dataset with all files compressed. '
  331. 'Specifically, .tgz files belonging to subject 1 should be placed '
  332. 'under the subdirectory \"s1\".')
  333. parser.add_argument(
  334. '--extracted',
  335. type=str,
  336. default=None,
  337. help='Directory of the extracted files. If not given, it will be '
  338. 'placed under the same parent directory as original_dir.')
  339. parser.add_argument(
  340. '--processed',
  341. type=str,
  342. default=None,
  343. help='Directory of the processed files. If not given, it will be '
  344. 'placed under the same parent directory as original_dir.')
  345. parser.add_argument(
  346. '--sample-rate',
  347. type=int,
  348. default=5,
  349. help='Downsample FPS to `1 / sample_rate`. Default: 5.')
  350. args = parser.parse_args()
  351. return args
  352. if __name__ == '__main__':
  353. args = parse_args()
  354. h36m = PreprocessH36m(
  355. metadata=args.metadata,
  356. original_dir=args.original,
  357. extracted_dir=args.extracted,
  358. processed_dir=args.processed,
  359. sample_rate=args.sample_rate)
  360. h36m.extract_tgz()
  361. h36m.generate_cameras_file()
  362. h36m.generate_annotations()