pascal_voc.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. import xml.etree.ElementTree as ET
  5. import numpy as np
  6. from mmengine.fileio import dump, list_from_file
  7. from mmengine.utils import mkdir_or_exist, track_progress
  8. from mmdet.evaluation import voc_classes
  9. label_ids = {name: i for i, name in enumerate(voc_classes())}
  10. def parse_xml(args):
  11. xml_path, img_path = args
  12. tree = ET.parse(xml_path)
  13. root = tree.getroot()
  14. size = root.find('size')
  15. w = int(size.find('width').text)
  16. h = int(size.find('height').text)
  17. bboxes = []
  18. labels = []
  19. bboxes_ignore = []
  20. labels_ignore = []
  21. for obj in root.findall('object'):
  22. name = obj.find('name').text
  23. label = label_ids[name]
  24. difficult = int(obj.find('difficult').text)
  25. bnd_box = obj.find('bndbox')
  26. bbox = [
  27. int(bnd_box.find('xmin').text),
  28. int(bnd_box.find('ymin').text),
  29. int(bnd_box.find('xmax').text),
  30. int(bnd_box.find('ymax').text)
  31. ]
  32. if difficult:
  33. bboxes_ignore.append(bbox)
  34. labels_ignore.append(label)
  35. else:
  36. bboxes.append(bbox)
  37. labels.append(label)
  38. if not bboxes:
  39. bboxes = np.zeros((0, 4))
  40. labels = np.zeros((0, ))
  41. else:
  42. bboxes = np.array(bboxes, ndmin=2) - 1
  43. labels = np.array(labels)
  44. if not bboxes_ignore:
  45. bboxes_ignore = np.zeros((0, 4))
  46. labels_ignore = np.zeros((0, ))
  47. else:
  48. bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
  49. labels_ignore = np.array(labels_ignore)
  50. annotation = {
  51. 'filename': img_path,
  52. 'width': w,
  53. 'height': h,
  54. 'ann': {
  55. 'bboxes': bboxes.astype(np.float32),
  56. 'labels': labels.astype(np.int64),
  57. 'bboxes_ignore': bboxes_ignore.astype(np.float32),
  58. 'labels_ignore': labels_ignore.astype(np.int64)
  59. }
  60. }
  61. return annotation
  62. def cvt_annotations(devkit_path, years, split, out_file):
  63. if not isinstance(years, list):
  64. years = [years]
  65. annotations = []
  66. for year in years:
  67. filelist = osp.join(devkit_path,
  68. f'VOC{year}/ImageSets/Main/{split}.txt')
  69. if not osp.isfile(filelist):
  70. print(f'filelist does not exist: {filelist}, '
  71. f'skip voc{year} {split}')
  72. return
  73. img_names = list_from_file(filelist)
  74. xml_paths = [
  75. osp.join(devkit_path, f'VOC{year}/Annotations/{img_name}.xml')
  76. for img_name in img_names
  77. ]
  78. img_paths = [
  79. f'VOC{year}/JPEGImages/{img_name}.jpg' for img_name in img_names
  80. ]
  81. part_annotations = track_progress(parse_xml,
  82. list(zip(xml_paths, img_paths)))
  83. annotations.extend(part_annotations)
  84. if out_file.endswith('json'):
  85. annotations = cvt_to_coco_json(annotations)
  86. dump(annotations, out_file)
  87. return annotations
  88. def cvt_to_coco_json(annotations):
  89. image_id = 0
  90. annotation_id = 0
  91. coco = dict()
  92. coco['images'] = []
  93. coco['type'] = 'instance'
  94. coco['categories'] = []
  95. coco['annotations'] = []
  96. image_set = set()
  97. def addAnnItem(annotation_id, image_id, category_id, bbox, difficult_flag):
  98. annotation_item = dict()
  99. annotation_item['segmentation'] = []
  100. seg = []
  101. # bbox[] is x1,y1,x2,y2
  102. # left_top
  103. seg.append(int(bbox[0]))
  104. seg.append(int(bbox[1]))
  105. # left_bottom
  106. seg.append(int(bbox[0]))
  107. seg.append(int(bbox[3]))
  108. # right_bottom
  109. seg.append(int(bbox[2]))
  110. seg.append(int(bbox[3]))
  111. # right_top
  112. seg.append(int(bbox[2]))
  113. seg.append(int(bbox[1]))
  114. annotation_item['segmentation'].append(seg)
  115. xywh = np.array(
  116. [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]])
  117. annotation_item['area'] = int(xywh[2] * xywh[3])
  118. if difficult_flag == 1:
  119. annotation_item['ignore'] = 0
  120. annotation_item['iscrowd'] = 1
  121. else:
  122. annotation_item['ignore'] = 0
  123. annotation_item['iscrowd'] = 0
  124. annotation_item['image_id'] = int(image_id)
  125. annotation_item['bbox'] = xywh.astype(int).tolist()
  126. annotation_item['category_id'] = int(category_id)
  127. annotation_item['id'] = int(annotation_id)
  128. coco['annotations'].append(annotation_item)
  129. return annotation_id + 1
  130. for category_id, name in enumerate(voc_classes()):
  131. category_item = dict()
  132. category_item['supercategory'] = str('none')
  133. category_item['id'] = int(category_id)
  134. category_item['name'] = str(name)
  135. coco['categories'].append(category_item)
  136. for ann_dict in annotations:
  137. file_name = ann_dict['filename']
  138. ann = ann_dict['ann']
  139. assert file_name not in image_set
  140. image_item = dict()
  141. image_item['id'] = int(image_id)
  142. image_item['file_name'] = str(file_name)
  143. image_item['height'] = int(ann_dict['height'])
  144. image_item['width'] = int(ann_dict['width'])
  145. coco['images'].append(image_item)
  146. image_set.add(file_name)
  147. bboxes = ann['bboxes'][:, :4]
  148. labels = ann['labels']
  149. for bbox_id in range(len(bboxes)):
  150. bbox = bboxes[bbox_id]
  151. label = labels[bbox_id]
  152. annotation_id = addAnnItem(
  153. annotation_id, image_id, label, bbox, difficult_flag=0)
  154. bboxes_ignore = ann['bboxes_ignore'][:, :4]
  155. labels_ignore = ann['labels_ignore']
  156. for bbox_id in range(len(bboxes_ignore)):
  157. bbox = bboxes_ignore[bbox_id]
  158. label = labels_ignore[bbox_id]
  159. annotation_id = addAnnItem(
  160. annotation_id, image_id, label, bbox, difficult_flag=1)
  161. image_id += 1
  162. return coco
  163. def parse_args():
  164. parser = argparse.ArgumentParser(
  165. description='Convert PASCAL VOC annotations to mmdetection format')
  166. parser.add_argument('devkit_path', help='pascal voc devkit path')
  167. parser.add_argument('-o', '--out-dir', help='output path')
  168. parser.add_argument(
  169. '--out-format',
  170. default='pkl',
  171. choices=('pkl', 'coco'),
  172. help='output format, "coco" indicates coco annotation format')
  173. args = parser.parse_args()
  174. return args
  175. def main():
  176. args = parse_args()
  177. devkit_path = args.devkit_path
  178. out_dir = args.out_dir if args.out_dir else devkit_path
  179. mkdir_or_exist(out_dir)
  180. years = []
  181. if osp.isdir(osp.join(devkit_path, 'VOC2007')):
  182. years.append('2007')
  183. if osp.isdir(osp.join(devkit_path, 'VOC2012')):
  184. years.append('2012')
  185. if '2007' in years and '2012' in years:
  186. years.append(['2007', '2012'])
  187. if not years:
  188. raise IOError(f'The devkit path {devkit_path} contains neither '
  189. '"VOC2007" nor "VOC2012" subfolder')
  190. out_fmt = f'.{args.out_format}'
  191. if args.out_format == 'coco':
  192. out_fmt = '.json'
  193. for year in years:
  194. if year == '2007':
  195. prefix = 'voc07'
  196. elif year == '2012':
  197. prefix = 'voc12'
  198. elif year == ['2007', '2012']:
  199. prefix = 'voc0712'
  200. for split in ['train', 'val', 'trainval']:
  201. dataset_name = prefix + '_' + split
  202. print(f'processing {dataset_name} ...')
  203. cvt_annotations(devkit_path, year, split,
  204. osp.join(out_dir, dataset_name + out_fmt))
  205. if not isinstance(year, list):
  206. dataset_name = prefix + '_test'
  207. print(f'processing {dataset_name} ...')
  208. cvt_annotations(devkit_path, year, 'test',
  209. osp.join(out_dir, dataset_name + out_fmt))
  210. print('Done!')
  211. if __name__ == '__main__':
  212. main()