# parse_cofw_dataset.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import json
  3. import os
  4. import time
  5. import cv2
  6. import h5py
  7. import numpy as np
  8. mat_files = ['COFW_train_color.mat', 'COFW_test_color.mat']
  9. dataset_dir = 'data/cofw/'
  10. image_root = os.path.join(dataset_dir, 'images/')
  11. annotation_root = os.path.join(dataset_dir, 'annotations/')
  12. os.makedirs(image_root, exist_ok=True)
  13. os.makedirs(annotation_root, exist_ok=True)
  14. cnt = 0
  15. for mat_file in mat_files:
  16. mat = h5py.File(os.path.join(dataset_dir, mat_file), 'r')
  17. if 'train' in mat_file:
  18. imgs = mat['IsTr']
  19. pts = mat['phisTr']
  20. bboxes = mat['bboxesTr']
  21. is_train = True
  22. json_file = 'cofw_train.json'
  23. else:
  24. imgs = mat['IsT']
  25. pts = mat['phisT']
  26. bboxes = mat['bboxesT']
  27. is_train = False
  28. json_file = 'cofw_test.json'
  29. images = []
  30. annotations = []
  31. num = pts.shape[1]
  32. for idx in range(0, num):
  33. cnt += 1
  34. img = np.array(mat[imgs[0, idx]]).transpose()
  35. keypoints = pts[:, idx].reshape(3, -1).transpose()
  36. # 2 for valid and 1 for occlusion
  37. keypoints[:, 2] = 2 - keypoints[:, 2]
  38. # matlab 1-index to python 0-index
  39. keypoints[:, :2] -= 1
  40. bbox = bboxes[:, idx]
  41. # check nonnegativity
  42. bbox[bbox < 0] = 0
  43. keypoints[keypoints < 0] = 0
  44. image = {}
  45. image['id'] = cnt
  46. image['file_name'] = f'{str(cnt).zfill(6)}.jpg'
  47. image['height'] = img.shape[0]
  48. image['width'] = img.shape[1]
  49. cv2.imwrite(
  50. os.path.join(image_root, image['file_name']),
  51. cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
  52. images.append(image)
  53. anno = {}
  54. anno['keypoints'] = keypoints.reshape(-1).tolist()
  55. anno['image_id'] = cnt
  56. anno['id'] = cnt
  57. anno['num_keypoints'] = len(keypoints) # all keypoints are labelled
  58. anno['bbox'] = bbox.tolist()
  59. anno['iscrowd'] = 0
  60. anno['area'] = anno['bbox'][2] * anno['bbox'][3]
  61. anno['category_id'] = 1
  62. annotations.append(anno)
  63. cocotype = {}
  64. cocotype['info'] = {}
  65. cocotype['info']['description'] = 'COFW Generated by MMPose Team'
  66. cocotype['info']['version'] = '1.0'
  67. cocotype['info']['year'] = time.strftime('%Y', time.localtime())
  68. cocotype['info']['date_created'] = time.strftime('%Y/%m/%d',
  69. time.localtime())
  70. cocotype['images'] = images
  71. cocotype['annotations'] = annotations
  72. cocotype['categories'] = [{
  73. 'supercategory': 'person',
  74. 'id': 1,
  75. 'name': 'face',
  76. 'keypoints': [],
  77. 'skeleton': []
  78. }]
  79. ann_path = os.path.join(annotation_root, json_file)
  80. json.dump(cocotype, open(ann_path, 'w'))
  81. print(f'done {ann_path}')