parse_animalpose_dataset.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import json
  3. import os
  4. import re
  5. import time
  6. import warnings
  7. import cv2
  8. import numpy as np
  9. import xmltodict
  10. from xtcocotools.coco import COCO
  11. np.random.seed(0)
  12. def list_all_files(root_dir, ext='.xml'):
  13. """List all files in the root directory and all its sub directories.
  14. :param root_dir: root directory
  15. :param ext: filename extension
  16. :return: list of files
  17. """
  18. files = []
  19. file_list = os.listdir(root_dir)
  20. for i in range(0, len(file_list)):
  21. path = os.path.join(root_dir, file_list[i])
  22. if os.path.isdir(path):
  23. files.extend(list_all_files(path))
  24. if os.path.isfile(path):
  25. if path.lower().endswith(ext):
  26. files.append(path)
  27. return files
  28. def get_anno_info():
  29. keypoints_info = [
  30. 'L_Eye',
  31. 'R_Eye',
  32. 'L_EarBase',
  33. 'R_EarBase',
  34. 'Nose',
  35. 'Throat',
  36. 'TailBase',
  37. 'Withers',
  38. 'L_F_Elbow',
  39. 'R_F_Elbow',
  40. 'L_B_Elbow',
  41. 'R_B_Elbow',
  42. 'L_F_Knee',
  43. 'R_F_Knee',
  44. 'L_B_Knee',
  45. 'R_B_Knee',
  46. 'L_F_Paw',
  47. 'R_F_Paw',
  48. 'L_B_Paw',
  49. 'R_B_Paw',
  50. ]
  51. skeleton_info = [[1, 2], [1, 3], [2, 4], [1, 5], [2, 5], [5, 6], [6, 8],
  52. [7, 8], [6, 9], [9, 13], [13, 17], [6, 10], [10, 14],
  53. [14, 18], [7, 11], [11, 15], [15, 19], [7, 12], [12, 16],
  54. [16, 20]]
  55. category_info = [{
  56. 'supercategory': 'animal',
  57. 'id': 1,
  58. 'name': 'animal',
  59. 'keypoints': keypoints_info,
  60. 'skeleton': skeleton_info
  61. }]
  62. return keypoints_info, skeleton_info, category_info
  63. def xml2coco_trainval(file_list, img_root, save_path, start_ann_id=0):
  64. """Save annotations in coco-format.
  65. :param file_list: list of data annotation files.
  66. :param img_root: the root dir to load images.
  67. :param save_path: the path to save transformed annotation file.
  68. :param start_ann_id: the starting point to count the annotation id.
  69. :param val_num: the number of annotated objects for validation.
  70. """
  71. images = []
  72. annotations = []
  73. img_ids = []
  74. ann_ids = []
  75. ann_id = start_ann_id
  76. name2id = {
  77. 'L_Eye': 0,
  78. 'R_Eye': 1,
  79. 'L_EarBase': 2,
  80. 'R_EarBase': 3,
  81. 'Nose': 4,
  82. 'Throat': 5,
  83. 'TailBase': 6,
  84. 'Withers': 7,
  85. 'L_F_Elbow': 8,
  86. 'R_F_Elbow': 9,
  87. 'L_B_Elbow': 10,
  88. 'R_B_Elbow': 11,
  89. 'L_F_Knee': 12,
  90. 'R_F_Knee': 13,
  91. 'L_B_Knee': 14,
  92. 'R_B_Knee': 15,
  93. 'L_F_Paw': 16,
  94. 'R_F_Paw': 17,
  95. 'L_B_Paw': 18,
  96. 'R_B_Paw': 19
  97. }
  98. for file in file_list:
  99. data_anno = xmltodict.parse(open(file).read())['annotation']
  100. img_id = int(data_anno['image'].split('_')[0] +
  101. data_anno['image'].split('_')[1])
  102. if img_id not in img_ids:
  103. image_name = 'VOC2012/JPEGImages/' + data_anno['image'] + '.jpg'
  104. img = cv2.imread(os.path.join(img_root, image_name))
  105. image = {}
  106. image['id'] = img_id
  107. image['file_name'] = image_name
  108. image['height'] = img.shape[0]
  109. image['width'] = img.shape[1]
  110. images.append(image)
  111. img_ids.append(img_id)
  112. else:
  113. pass
  114. keypoint_anno = data_anno['keypoints']['keypoint']
  115. assert len(keypoint_anno) == 20
  116. keypoints = np.zeros([20, 3], dtype=np.float32)
  117. for kpt_anno in keypoint_anno:
  118. keypoint_name = kpt_anno['@name']
  119. keypoint_id = name2id[keypoint_name]
  120. visibility = int(kpt_anno['@visible'])
  121. if visibility == 0:
  122. continue
  123. else:
  124. keypoints[keypoint_id, 0] = float(kpt_anno['@x'])
  125. keypoints[keypoint_id, 1] = float(kpt_anno['@y'])
  126. keypoints[keypoint_id, 2] = 2
  127. anno = {}
  128. anno['keypoints'] = keypoints.reshape(-1).tolist()
  129. anno['image_id'] = img_id
  130. anno['id'] = ann_id
  131. anno['num_keypoints'] = int(sum(keypoints[:, 2] > 0))
  132. visible_bounds = data_anno['visible_bounds']
  133. anno['bbox'] = [
  134. float(visible_bounds['@xmin']),
  135. float(visible_bounds['@ymin']),
  136. float(visible_bounds['@width']),
  137. float(visible_bounds['@height'])
  138. ]
  139. anno['iscrowd'] = 0
  140. anno['area'] = float(anno['bbox'][2] * anno['bbox'][3])
  141. anno['category_id'] = 1
  142. annotations.append(anno)
  143. ann_ids.append(ann_id)
  144. ann_id += 1
  145. cocotype = {}
  146. cocotype['info'] = {}
  147. cocotype['info'][
  148. 'description'] = 'AnimalPose dataset Generated by MMPose Team'
  149. cocotype['info']['version'] = '1.0'
  150. cocotype['info']['year'] = time.strftime('%Y', time.localtime())
  151. cocotype['info']['date_created'] = time.strftime('%Y/%m/%d',
  152. time.localtime())
  153. cocotype['images'] = images
  154. cocotype['annotations'] = annotations
  155. keypoints_info, skeleton_info, category_info = get_anno_info()
  156. cocotype['categories'] = category_info
  157. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  158. json.dump(cocotype, open(save_path, 'w'), indent=4)
  159. print('number of images:', len(img_ids))
  160. print('number of annotations:', len(ann_ids))
  161. print(f'done {save_path}')
  162. def xml2coco_test(file_list, img_root, save_path, start_ann_id=0):
  163. """Save annotations in coco-format.
  164. :param file_list: list of data annotation files.
  165. :param img_root: the root dir to load images.
  166. :param save_path: the path to save transformed annotation file.
  167. :param start_ann_id: the starting point to count the annotation id.
  168. """
  169. images = []
  170. annotations = []
  171. img_ids = []
  172. ann_ids = []
  173. ann_id = start_ann_id
  174. name2id = {
  175. 'L_eye': 0,
  176. 'R_eye': 1,
  177. 'L_ear': 2,
  178. 'R_ear': 3,
  179. 'Nose': 4,
  180. 'Throat': 5,
  181. 'Tail': 6,
  182. 'withers': 7,
  183. 'L_F_elbow': 8,
  184. 'R_F_elbow': 9,
  185. 'L_B_elbow': 10,
  186. 'R_B_elbow': 11,
  187. 'L_F_knee': 12,
  188. 'R_F_knee': 13,
  189. 'L_B_knee': 14,
  190. 'R_B_knee': 15,
  191. 'L_F_paw': 16,
  192. 'R_F_paw': 17,
  193. 'L_B_paw': 18,
  194. 'R_B_paw': 19
  195. }
  196. cat2id = {'cat': 1, 'cow': 2, 'dog': 3, 'horse': 4, 'sheep': 5}
  197. for file in file_list:
  198. data_anno = xmltodict.parse(open(file).read())['annotation']
  199. category_id = cat2id[data_anno['category']]
  200. img_id = category_id * 1000 + int(
  201. re.findall(r'\d+', data_anno['image'])[0])
  202. assert img_id not in img_ids
  203. # prepare images
  204. image_name = os.path.join('animalpose_image_part2',
  205. data_anno['category'], data_anno['image'])
  206. img = cv2.imread(os.path.join(img_root, image_name))
  207. image = {}
  208. image['id'] = img_id
  209. image['file_name'] = image_name
  210. image['height'] = img.shape[0]
  211. image['width'] = img.shape[1]
  212. images.append(image)
  213. img_ids.append(img_id)
  214. # prepare annotations
  215. keypoint_anno = data_anno['keypoints']['keypoint']
  216. keypoints = np.zeros([20, 3], dtype=np.float32)
  217. for kpt_anno in keypoint_anno:
  218. keypoint_name = kpt_anno['@name']
  219. keypoint_id = name2id[keypoint_name]
  220. visibility = int(kpt_anno['@visible'])
  221. if visibility == 0:
  222. continue
  223. else:
  224. keypoints[keypoint_id, 0] = float(kpt_anno['@x'])
  225. keypoints[keypoint_id, 1] = float(kpt_anno['@y'])
  226. keypoints[keypoint_id, 2] = 2
  227. anno = {}
  228. anno['keypoints'] = keypoints.reshape(-1).tolist()
  229. anno['image_id'] = img_id
  230. anno['id'] = ann_id
  231. anno['num_keypoints'] = int(sum(keypoints[:, 2] > 0))
  232. visible_bounds = data_anno['visible_bounds']
  233. anno['bbox'] = [
  234. float(visible_bounds['@xmin']),
  235. float(visible_bounds['@xmax']
  236. ), # typo in original xml: should be 'ymin'
  237. float(visible_bounds['@width']),
  238. float(visible_bounds['@height'])
  239. ]
  240. anno['iscrowd'] = 0
  241. anno['area'] = float(anno['bbox'][2] * anno['bbox'][3])
  242. anno['category_id'] = 1
  243. annotations.append(anno)
  244. ann_ids.append(ann_id)
  245. ann_id += 1
  246. cocotype = {}
  247. cocotype['info'] = {}
  248. cocotype['info'][
  249. 'description'] = 'AnimalPose dataset Generated by MMPose Team'
  250. cocotype['info']['version'] = '1.0'
  251. cocotype['info']['year'] = time.strftime('%Y', time.localtime())
  252. cocotype['info']['date_created'] = time.strftime('%Y/%m/%d',
  253. time.localtime())
  254. cocotype['images'] = images
  255. cocotype['annotations'] = annotations
  256. keypoints_info, skeleton_info, category_info = get_anno_info()
  257. cocotype['categories'] = category_info
  258. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  259. json.dump(cocotype, open(save_path, 'w'), indent=4)
  260. print('=========================================================')
  261. print('number of images:', len(img_ids))
  262. print('number of annotations:', len(ann_ids))
  263. print(f'done {save_path}')
  264. def split_train_val(work_dir, trainval_file, train_file, val_file,
  265. val_ann_num):
  266. """Split train-val json file into training and validation files.
  267. :param work_dir: path to load train-val json file, and save split files.
  268. :param trainval_file: The input json file combining both train and val.
  269. :param trainval_file: The output json file for training.
  270. :param trainval_file: The output json file for validation.
  271. :param val_ann_num: the number of validation annotations.
  272. """
  273. coco = COCO(os.path.join(work_dir, trainval_file))
  274. img_list = list(coco.imgs.keys())
  275. np.random.shuffle(img_list)
  276. count = 0
  277. images_train = []
  278. images_val = []
  279. annotations_train = []
  280. annotations_val = []
  281. for img_id in img_list:
  282. ann_ids = coco.getAnnIds(img_id)
  283. if count + len(ann_ids) <= val_ann_num:
  284. # for validation
  285. count += len(ann_ids)
  286. images_val.append(coco.imgs[img_id])
  287. for ann_id in ann_ids:
  288. annotations_val.append(coco.anns[ann_id])
  289. else:
  290. images_train.append(coco.imgs[img_id])
  291. for ann_id in ann_ids:
  292. annotations_train.append(coco.anns[ann_id])
  293. if count == val_ann_num:
  294. print(f'We have found {count} annotations for validation.')
  295. else:
  296. warnings.warn(
  297. f'We only found {count} annotations, instead of {val_ann_num}.')
  298. cocotype_train = {}
  299. cocotype_val = {}
  300. keypoints_info, skeleton_info, category_info = get_anno_info()
  301. cocotype_train['info'] = {}
  302. cocotype_train['info'][
  303. 'description'] = 'AnimalPose dataset Generated by MMPose Team'
  304. cocotype_train['info']['version'] = '1.0'
  305. cocotype_train['info']['year'] = time.strftime('%Y', time.localtime())
  306. cocotype_train['info']['date_created'] = time.strftime(
  307. '%Y/%m/%d', time.localtime())
  308. cocotype_train['images'] = images_train
  309. cocotype_train['annotations'] = annotations_train
  310. cocotype_train['categories'] = category_info
  311. json.dump(
  312. cocotype_train,
  313. open(os.path.join(work_dir, train_file), 'w'),
  314. indent=4)
  315. print('=========================================================')
  316. print('number of images:', len(images_train))
  317. print('number of annotations:', len(annotations_train))
  318. print(f'done {train_file}')
  319. cocotype_val['info'] = {}
  320. cocotype_val['info'][
  321. 'description'] = 'AnimalPose dataset Generated by MMPose Team'
  322. cocotype_val['info']['version'] = '1.0'
  323. cocotype_val['info']['year'] = time.strftime('%Y', time.localtime())
  324. cocotype_val['info']['date_created'] = time.strftime(
  325. '%Y/%m/%d', time.localtime())
  326. cocotype_val['images'] = images_val
  327. cocotype_val['annotations'] = annotations_val
  328. cocotype_val['categories'] = category_info
  329. json.dump(
  330. cocotype_val, open(os.path.join(work_dir, val_file), 'w'), indent=4)
  331. print('=========================================================')
  332. print('number of images:', len(images_val))
  333. print('number of annotations:', len(annotations_val))
  334. print(f'done {val_file}')
  335. dataset_dir = 'data/animalpose/'
  336. # We choose the images from PascalVOC for train + val
  337. # In total, train+val: 3608 images, 5117 annotations
  338. xml2coco_trainval(
  339. list_all_files(os.path.join(dataset_dir, 'PASCAL2011_animal_annotation')),
  340. dataset_dir,
  341. os.path.join(dataset_dir, 'annotations', 'animalpose_trainval.json'),
  342. start_ann_id=1000000)
  343. # train: 2798 images, 4000 annotations
  344. # val: 810 images, 1117 annotations
  345. split_train_val(
  346. os.path.join(dataset_dir, 'annotations'),
  347. 'animalpose_trainval.json',
  348. 'animalpose_train.json',
  349. 'animalpose_val.json',
  350. val_ann_num=1117)
  351. # We choose the remaining 1000 images for test
  352. # 1000 images, 1000 annotations
  353. xml2coco_test(
  354. list_all_files(os.path.join(dataset_dir, 'animalpose_anno2')),
  355. dataset_dir,
  356. os.path.join(dataset_dir, 'annotations', 'animalpose_test.json'),
  357. start_ann_id=0)