cityscapes.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import os.path as osp
  5. import cityscapesscripts.helpers.labels as CSLabels
  6. import mmcv
  7. import numpy as np
  8. import pycocotools.mask as maskUtils
  9. from mmengine.fileio import dump
  10. from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
  11. track_progress)
  12. def collect_files(img_dir, gt_dir):
  13. suffix = 'leftImg8bit.png'
  14. files = []
  15. for img_file in glob.glob(osp.join(img_dir, '**/*.png')):
  16. assert img_file.endswith(suffix), img_file
  17. inst_file = gt_dir + img_file[
  18. len(img_dir):-len(suffix)] + 'gtFine_instanceIds.png'
  19. # Note that labelIds are not converted to trainId for seg map
  20. segm_file = gt_dir + img_file[
  21. len(img_dir):-len(suffix)] + 'gtFine_labelIds.png'
  22. files.append((img_file, inst_file, segm_file))
  23. assert len(files), f'No images found in {img_dir}'
  24. print(f'Loaded {len(files)} images from {img_dir}')
  25. return files
  26. def collect_annotations(files, nproc=1):
  27. print('Loading annotation images')
  28. if nproc > 1:
  29. images = track_parallel_progress(load_img_info, files, nproc=nproc)
  30. else:
  31. images = track_progress(load_img_info, files)
  32. return images
  33. def load_img_info(files):
  34. img_file, inst_file, segm_file = files
  35. inst_img = mmcv.imread(inst_file, 'unchanged')
  36. # ids < 24 are stuff labels (filtering them first is about 5% faster)
  37. unique_inst_ids = np.unique(inst_img[inst_img >= 24])
  38. anno_info = []
  39. for inst_id in unique_inst_ids:
  40. # For non-crowd annotations, inst_id // 1000 is the label_id
  41. # Crowd annotations have <1000 instance ids
  42. label_id = inst_id // 1000 if inst_id >= 1000 else inst_id
  43. label = CSLabels.id2label[label_id]
  44. if not label.hasInstances or label.ignoreInEval:
  45. continue
  46. category_id = label.id
  47. iscrowd = int(inst_id < 1000)
  48. mask = np.asarray(inst_img == inst_id, dtype=np.uint8, order='F')
  49. mask_rle = maskUtils.encode(mask[:, :, None])[0]
  50. area = maskUtils.area(mask_rle)
  51. # convert to COCO style XYWH format
  52. bbox = maskUtils.toBbox(mask_rle)
  53. # for json encoding
  54. mask_rle['counts'] = mask_rle['counts'].decode()
  55. anno = dict(
  56. iscrowd=iscrowd,
  57. category_id=category_id,
  58. bbox=bbox.tolist(),
  59. area=area.tolist(),
  60. segmentation=mask_rle)
  61. anno_info.append(anno)
  62. video_name = osp.basename(osp.dirname(img_file))
  63. img_info = dict(
  64. # remove img_prefix for filename
  65. file_name=osp.join(video_name, osp.basename(img_file)),
  66. height=inst_img.shape[0],
  67. width=inst_img.shape[1],
  68. anno_info=anno_info,
  69. segm_file=osp.join(video_name, osp.basename(segm_file)))
  70. return img_info
  71. def cvt_annotations(image_infos, out_json_name):
  72. out_json = dict()
  73. img_id = 0
  74. ann_id = 0
  75. out_json['images'] = []
  76. out_json['categories'] = []
  77. out_json['annotations'] = []
  78. for image_info in image_infos:
  79. image_info['id'] = img_id
  80. anno_infos = image_info.pop('anno_info')
  81. out_json['images'].append(image_info)
  82. for anno_info in anno_infos:
  83. anno_info['image_id'] = img_id
  84. anno_info['id'] = ann_id
  85. out_json['annotations'].append(anno_info)
  86. ann_id += 1
  87. img_id += 1
  88. for label in CSLabels.labels:
  89. if label.hasInstances and not label.ignoreInEval:
  90. cat = dict(id=label.id, name=label.name)
  91. out_json['categories'].append(cat)
  92. if len(out_json['annotations']) == 0:
  93. out_json.pop('annotations')
  94. dump(out_json, out_json_name)
  95. return out_json
  96. def parse_args():
  97. parser = argparse.ArgumentParser(
  98. description='Convert Cityscapes annotations to COCO format')
  99. parser.add_argument('cityscapes_path', help='cityscapes data path')
  100. parser.add_argument('--img-dir', default='leftImg8bit', type=str)
  101. parser.add_argument('--gt-dir', default='gtFine', type=str)
  102. parser.add_argument('-o', '--out-dir', help='output path')
  103. parser.add_argument(
  104. '--nproc', default=1, type=int, help='number of process')
  105. args = parser.parse_args()
  106. return args
  107. def main():
  108. args = parse_args()
  109. cityscapes_path = args.cityscapes_path
  110. out_dir = args.out_dir if args.out_dir else cityscapes_path
  111. mkdir_or_exist(out_dir)
  112. img_dir = osp.join(cityscapes_path, args.img_dir)
  113. gt_dir = osp.join(cityscapes_path, args.gt_dir)
  114. set_name = dict(
  115. train='instancesonly_filtered_gtFine_train.json',
  116. val='instancesonly_filtered_gtFine_val.json',
  117. test='instancesonly_filtered_gtFine_test.json')
  118. for split, json_name in set_name.items():
  119. print(f'Converting {split} into {json_name}')
  120. with Timer(print_tmpl='It took {}s to convert Cityscapes annotation'):
  121. files = collect_files(
  122. osp.join(img_dir, split), osp.join(gt_dir, split))
  123. image_infos = collect_annotations(files, nproc=args.nproc)
  124. cvt_annotations(image_infos, osp.join(out_dir, json_name))
# Entry point: run the conversion only when executed as a script.
if __name__ == '__main__':
    main()