# split_coco.py (3.5 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. import numpy as np
  5. from mmengine.fileio import dump, load
  6. from mmengine.utils import mkdir_or_exist, track_parallel_progress
  7. prog_description = '''K-Fold coco split.
  8. To split coco data for semi-supervised object detection:
  9. python tools/misc/split_coco.py
  10. '''
  11. def parse_args():
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument(
  14. '--data-root',
  15. type=str,
  16. help='The data root of coco dataset.',
  17. default='./data/coco/')
  18. parser.add_argument(
  19. '--out-dir',
  20. type=str,
  21. help='The output directory of coco semi-supervised annotations.',
  22. default='./data/coco/semi_anns/')
  23. parser.add_argument(
  24. '--labeled-percent',
  25. type=float,
  26. nargs='+',
  27. help='The percentage of labeled data in the training set.',
  28. default=[1, 2, 5, 10])
  29. parser.add_argument(
  30. '--fold',
  31. type=int,
  32. help='K-fold cross validation for semi-supervised object detection.',
  33. default=5)
  34. args = parser.parse_args()
  35. return args
  36. def split_coco(data_root, out_dir, percent, fold):
  37. """Split COCO data for Semi-supervised object detection.
  38. Args:
  39. data_root (str): The data root of coco dataset.
  40. out_dir (str): The output directory of coco semi-supervised
  41. annotations.
  42. percent (float): The percentage of labeled data in the training set.
  43. fold (int): The fold of dataset and set as random seed for data split.
  44. """
  45. def save_anns(name, images, annotations):
  46. sub_anns = dict()
  47. sub_anns['images'] = images
  48. sub_anns['annotations'] = annotations
  49. sub_anns['licenses'] = anns['licenses']
  50. sub_anns['categories'] = anns['categories']
  51. sub_anns['info'] = anns['info']
  52. mkdir_or_exist(out_dir)
  53. dump(sub_anns, f'{out_dir}/{name}.json')
  54. # set random seed with the fold
  55. np.random.seed(fold)
  56. ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
  57. anns = load(ann_file)
  58. image_list = anns['images']
  59. labeled_total = int(percent / 100. * len(image_list))
  60. labeled_inds = set(
  61. np.random.choice(range(len(image_list)), size=labeled_total))
  62. labeled_ids, labeled_images, unlabeled_images = [], [], []
  63. for i in range(len(image_list)):
  64. if i in labeled_inds:
  65. labeled_images.append(image_list[i])
  66. labeled_ids.append(image_list[i]['id'])
  67. else:
  68. unlabeled_images.append(image_list[i])
  69. # get all annotations of labeled images
  70. labeled_ids = set(labeled_ids)
  71. labeled_annotations, unlabeled_annotations = [], []
  72. for ann in anns['annotations']:
  73. if ann['image_id'] in labeled_ids:
  74. labeled_annotations.append(ann)
  75. else:
  76. unlabeled_annotations.append(ann)
  77. # save labeled and unlabeled
  78. labeled_name = f'instances_train2017.{fold}@{percent}'
  79. unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled'
  80. save_anns(labeled_name, labeled_images, labeled_annotations)
  81. save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)
  82. def multi_wrapper(args):
  83. return split_coco(*args)
  84. if __name__ == '__main__':
  85. args = parse_args()
  86. arguments_list = [(args.data_root, args.out_dir, p, f)
  87. for f in range(1, args.fold + 1)
  88. for p in args.labeled_percent]
  89. track_parallel_progress(multi_wrapper, arguments_list, args.fold)