parse_deepposekit_dataset.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import json
  3. import os
  4. import time
  5. import cv2
  6. import h5py
  7. import numpy as np
  8. np.random.seed(0)
  9. def save_coco_anno(keypoints_all,
  10. annotated_all,
  11. imgs_all,
  12. keypoints_info,
  13. skeleton_info,
  14. dataset,
  15. img_root,
  16. save_path,
  17. start_img_id=0,
  18. start_ann_id=0):
  19. """Save annotations in coco-format.
  20. :param keypoints_all: keypoint annotations.
  21. :param annotated_all: images annotated or not.
  22. :param imgs_all: the array of images.
  23. :param keypoints_info: information about keypoint name.
  24. :param skeleton_info: information about skeleton connection.
  25. :param dataset: information about dataset name.
  26. :param img_root: the path to save images.
  27. :param save_path: the path to save transformed annotation file.
  28. :param start_img_id: the starting point to count the image id.
  29. :param start_ann_id: the starting point to count the annotation id.
  30. """
  31. images = []
  32. annotations = []
  33. img_id = start_img_id
  34. ann_id = start_ann_id
  35. num_annotations, keypoints_num, _ = keypoints_all.shape
  36. for i in range(num_annotations):
  37. img = imgs_all[i]
  38. keypoints = np.concatenate(
  39. [keypoints_all[i], annotated_all[i][:, None] * 2], axis=1)
  40. min_x, min_y = np.min(keypoints[keypoints[:, 2] > 0, :2], axis=0)
  41. max_x, max_y = np.max(keypoints[keypoints[:, 2] > 0, :2], axis=0)
  42. anno = {}
  43. anno['keypoints'] = keypoints.reshape(-1).tolist()
  44. anno['image_id'] = img_id
  45. anno['id'] = ann_id
  46. anno['num_keypoints'] = int(sum(keypoints[:, 2] > 0))
  47. anno['bbox'] = [
  48. float(min_x),
  49. float(min_y),
  50. float(max_x - min_x + 1),
  51. float(max_y - min_y + 1)
  52. ]
  53. anno['iscrowd'] = 0
  54. anno['area'] = anno['bbox'][2] * anno['bbox'][3]
  55. anno['category_id'] = 1
  56. annotations.append(anno)
  57. ann_id += 1
  58. image = {}
  59. image['id'] = img_id
  60. image['file_name'] = f'{img_id}.jpg'
  61. image['height'] = img.shape[0]
  62. image['width'] = img.shape[1]
  63. images.append(image)
  64. img_id += 1
  65. cv2.imwrite(os.path.join(img_root, image['file_name']), img)
  66. skeleton = np.concatenate(
  67. [np.arange(keypoints_num)[:, None], skeleton_info[:, 0][:, None]],
  68. axis=1) + 1
  69. skeleton = skeleton[skeleton.min(axis=1) > 0]
  70. cocotype = {}
  71. cocotype['info'] = {}
  72. cocotype['info'][
  73. 'description'] = 'DeepPoseKit-Data Generated by MMPose Team'
  74. cocotype['info']['version'] = '1.0'
  75. cocotype['info']['year'] = time.strftime('%Y', time.localtime())
  76. cocotype['info']['date_created'] = time.strftime('%Y/%m/%d',
  77. time.localtime())
  78. cocotype['images'] = images
  79. cocotype['annotations'] = annotations
  80. cocotype['categories'] = [{
  81. 'supercategory': 'animal',
  82. 'id': 1,
  83. 'name': dataset,
  84. 'keypoints': keypoints_info,
  85. 'skeleton': skeleton.tolist()
  86. }]
  87. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  88. json.dump(cocotype, open(save_path, 'w'), indent=4)
  89. print('number of images:', img_id)
  90. print('number of annotations:', ann_id)
  91. print(f'done {save_path}')
  92. for dataset in ['fly', 'locust', 'zebra']:
  93. keypoints_info = []
  94. if dataset == 'fly':
  95. keypoints_info = [
  96. 'head', 'eyeL', 'eyeR', 'neck', 'thorax', 'abdomen', 'forelegR1',
  97. 'forelegR2', 'forelegR3', 'forelegR4', 'midlegR1', 'midlegR2',
  98. 'midlegR3', 'midlegR4', 'hindlegR1', 'hindlegR2', 'hindlegR3',
  99. 'hindlegR4', 'forelegL1', 'forelegL2', 'forelegL3', 'forelegL4',
  100. 'midlegL1', 'midlegL2', 'midlegL3', 'midlegL4', 'hindlegL1',
  101. 'hindlegL2', 'hindlegL3', 'hindlegL4', 'wingL', 'wingR'
  102. ]
  103. elif dataset == 'locust':
  104. keypoints_info = [
  105. 'head', 'neck', 'thorax', 'abdomen1', 'abdomen2', 'anttipL',
  106. 'antbaseL', 'eyeL', 'forelegL1', 'forelegL2', 'forelegL3',
  107. 'forelegL4', 'midlegL1', 'midlegL2', 'midlegL3', 'midlegL4',
  108. 'hindlegL1', 'hindlegL2', 'hindlegL3', 'hindlegL4', 'anttipR',
  109. 'antbaseR', 'eyeR', 'forelegR1', 'forelegR2', 'forelegR3',
  110. 'forelegR4', 'midlegR1', 'midlegR2', 'midlegR3', 'midlegR4',
  111. 'hindlegR1', 'hindlegR2', 'hindlegR3', 'hindlegR4'
  112. ]
  113. elif dataset == 'zebra':
  114. keypoints_info = [
  115. 'snout', 'head', 'neck', 'forelegL1', 'forelegR1', 'hindlegL1',
  116. 'hindlegR1', 'tailbase', 'tailtip'
  117. ]
  118. else:
  119. NotImplementedError()
  120. dataset_dir = f'data/DeepPoseKit-Data/datasets/{dataset}'
  121. with h5py.File(
  122. os.path.join(dataset_dir, 'annotation_data_release.h5'), 'r') as f:
  123. # List all groups
  124. annotations = np.array(f['annotations'])
  125. annotated = np.array(f['annotated'])
  126. images = np.array(f['images'])
  127. skeleton_info = np.array(f['skeleton'])
  128. annotation_num, kpt_num, _ = annotations.shape
  129. data_list = np.arange(0, annotation_num)
  130. np.random.shuffle(data_list)
  131. val_data_num = annotation_num // 10
  132. train_data_num = annotation_num - val_data_num
  133. train_list = data_list[0:train_data_num]
  134. val_list = data_list[train_data_num:]
  135. img_root = os.path.join(dataset_dir, 'images')
  136. os.makedirs(img_root, exist_ok=True)
  137. save_coco_anno(
  138. annotations[train_list], annotated[train_list], images[train_list],
  139. keypoints_info, skeleton_info, dataset, img_root,
  140. os.path.join(dataset_dir, 'annotations', f'{dataset}_train.json'))
  141. save_coco_anno(
  142. annotations[val_list],
  143. annotated[val_list],
  144. images[val_list],
  145. keypoints_info,
  146. skeleton_info,
  147. dataset,
  148. img_root,
  149. os.path.join(dataset_dir, 'annotations', f'{dataset}_test.json'),
  150. start_img_id=train_data_num,
  151. start_ann_id=train_data_num)