# openimages.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import csv
  3. import os.path as osp
  4. from collections import defaultdict
  5. from typing import Dict, List, Optional
  6. import numpy as np
  7. from mmengine.fileio import get_local_path, load
  8. from mmengine.utils import is_abs
  9. from mmdet.registry import DATASETS
  10. from .base_det_dataset import BaseDetDataset
  11. @DATASETS.register_module()
  12. class OpenImagesDataset(BaseDetDataset):
  13. """Open Images dataset for detection.
  14. Args:
  15. ann_file (str): Annotation file path.
  16. label_file (str): File path of the label description file that
  17. maps the classes names in MID format to their short
  18. descriptions.
  19. meta_file (str): File path to get image metas.
  20. hierarchy_file (str): The file path of the class hierarchy.
  21. image_level_ann_file (str): Human-verified image level annotation,
  22. which is used in evaluation.
  23. backend_args (dict, optional): Arguments to instantiate the
  24. corresponding backend. Defaults to None.
  25. """
  26. METAINFO: dict = dict(dataset_type='oid_v6')
  27. def __init__(self,
  28. label_file: str,
  29. meta_file: str,
  30. hierarchy_file: str,
  31. image_level_ann_file: Optional[str] = None,
  32. **kwargs) -> None:
  33. self.label_file = label_file
  34. self.meta_file = meta_file
  35. self.hierarchy_file = hierarchy_file
  36. self.image_level_ann_file = image_level_ann_file
  37. super().__init__(**kwargs)
  38. def load_data_list(self) -> List[dict]:
  39. """Load annotations from an annotation file named as ``self.ann_file``
  40. Returns:
  41. List[dict]: A list of annotation.
  42. """
  43. classes_names, label_id_mapping = self._parse_label_file(
  44. self.label_file)
  45. self._metainfo['classes'] = classes_names
  46. self.label_id_mapping = label_id_mapping
  47. if self.image_level_ann_file is not None:
  48. img_level_anns = self._parse_img_level_ann(
  49. self.image_level_ann_file)
  50. else:
  51. img_level_anns = None
  52. # OpenImagesMetric can get the relation matrix from the dataset meta
  53. relation_matrix = self._get_relation_matrix(self.hierarchy_file)
  54. self._metainfo['RELATION_MATRIX'] = relation_matrix
  55. data_list = []
  56. with get_local_path(
  57. self.ann_file, backend_args=self.backend_args) as local_path:
  58. with open(local_path, 'r') as f:
  59. reader = csv.reader(f)
  60. last_img_id = None
  61. instances = []
  62. for i, line in enumerate(reader):
  63. if i == 0:
  64. continue
  65. img_id = line[0]
  66. if last_img_id is None:
  67. last_img_id = img_id
  68. label_id = line[2]
  69. assert label_id in self.label_id_mapping
  70. label = int(self.label_id_mapping[label_id])
  71. bbox = [
  72. float(line[4]), # xmin
  73. float(line[6]), # ymin
  74. float(line[5]), # xmax
  75. float(line[7]) # ymax
  76. ]
  77. is_occluded = True if int(line[8]) == 1 else False
  78. is_truncated = True if int(line[9]) == 1 else False
  79. is_group_of = True if int(line[10]) == 1 else False
  80. is_depiction = True if int(line[11]) == 1 else False
  81. is_inside = True if int(line[12]) == 1 else False
  82. instance = dict(
  83. bbox=bbox,
  84. bbox_label=label,
  85. ignore_flag=0,
  86. is_occluded=is_occluded,
  87. is_truncated=is_truncated,
  88. is_group_of=is_group_of,
  89. is_depiction=is_depiction,
  90. is_inside=is_inside)
  91. last_img_path = osp.join(self.data_prefix['img'],
  92. f'{last_img_id}.jpg')
  93. if img_id != last_img_id:
  94. # switch to a new image, record previous image's data.
  95. data_info = dict(
  96. img_path=last_img_path,
  97. img_id=last_img_id,
  98. instances=instances,
  99. )
  100. data_list.append(data_info)
  101. instances = []
  102. instances.append(instance)
  103. last_img_id = img_id
  104. data_list.append(
  105. dict(
  106. img_path=last_img_path,
  107. img_id=last_img_id,
  108. instances=instances,
  109. ))
  110. # add image metas to data list
  111. img_metas = load(
  112. self.meta_file, file_format='pkl', backend_args=self.backend_args)
  113. assert len(img_metas) == len(data_list)
  114. for i, meta in enumerate(img_metas):
  115. img_id = data_list[i]['img_id']
  116. assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
  117. h, w = meta['ori_shape'][:2]
  118. data_list[i]['height'] = h
  119. data_list[i]['width'] = w
  120. # denormalize bboxes
  121. for j in range(len(data_list[i]['instances'])):
  122. data_list[i]['instances'][j]['bbox'][0] *= w
  123. data_list[i]['instances'][j]['bbox'][2] *= w
  124. data_list[i]['instances'][j]['bbox'][1] *= h
  125. data_list[i]['instances'][j]['bbox'][3] *= h
  126. # add image-level annotation
  127. if img_level_anns is not None:
  128. img_labels = []
  129. confidences = []
  130. img_ann_list = img_level_anns.get(img_id, [])
  131. for ann in img_ann_list:
  132. img_labels.append(int(ann['image_level_label']))
  133. confidences.append(float(ann['confidence']))
  134. data_list[i]['image_level_labels'] = np.array(
  135. img_labels, dtype=np.int64)
  136. data_list[i]['confidences'] = np.array(
  137. confidences, dtype=np.float32)
  138. return data_list
  139. def _parse_label_file(self, label_file: str) -> tuple:
  140. """Get classes name and index mapping from cls-label-description file.
  141. Args:
  142. label_file (str): File path of the label description file that
  143. maps the classes names in MID format to their short
  144. descriptions.
  145. Returns:
  146. tuple: Class name of OpenImages.
  147. """
  148. index_list = []
  149. classes_names = []
  150. with get_local_path(
  151. label_file, backend_args=self.backend_args) as local_path:
  152. with open(local_path, 'r') as f:
  153. reader = csv.reader(f)
  154. for line in reader:
  155. # self.cat2label[line[0]] = line[1]
  156. classes_names.append(line[1])
  157. index_list.append(line[0])
  158. index_mapping = {index: i for i, index in enumerate(index_list)}
  159. return classes_names, index_mapping
  160. def _parse_img_level_ann(self,
  161. img_level_ann_file: str) -> Dict[str, List[dict]]:
  162. """Parse image level annotations from csv style ann_file.
  163. Args:
  164. img_level_ann_file (str): CSV style image level annotation
  165. file path.
  166. Returns:
  167. Dict[str, List[dict]]: Annotations where item of the defaultdict
  168. indicates an image, each of which has (n) dicts.
  169. Keys of dicts are:
  170. - `image_level_label` (int): Label id.
  171. - `confidence` (float): Labels that are human-verified to be
  172. present in an image have confidence = 1 (positive labels).
  173. Labels that are human-verified to be absent from an image
  174. have confidence = 0 (negative labels). Machine-generated
  175. labels have fractional confidences, generally >= 0.5.
  176. The higher the confidence, the smaller the chance for
  177. the label to be a false positive.
  178. """
  179. item_lists = defaultdict(list)
  180. with get_local_path(
  181. img_level_ann_file,
  182. backend_args=self.backend_args) as local_path:
  183. with open(local_path, 'r') as f:
  184. reader = csv.reader(f)
  185. for i, line in enumerate(reader):
  186. if i == 0:
  187. continue
  188. img_id = line[0]
  189. item_lists[img_id].append(
  190. dict(
  191. image_level_label=int(
  192. self.label_id_mapping[line[2]]),
  193. confidence=float(line[3])))
  194. return item_lists
  195. def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
  196. """Get the matrix of class hierarchy from the hierarchy file. Hierarchy
  197. for 600 classes can be found at https://storage.googleapis.com/openimag
  198. es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
  199. Args:
  200. hierarchy_file (str): File path to the hierarchy for classes.
  201. Returns:
  202. np.ndarray: The matrix of the corresponding relationship between
  203. the parent class and the child class, of shape
  204. (class_num, class_num).
  205. """ # noqa
  206. hierarchy = load(
  207. hierarchy_file, file_format='json', backend_args=self.backend_args)
  208. class_num = len(self._metainfo['classes'])
  209. relation_matrix = np.eye(class_num, class_num)
  210. relation_matrix = self._convert_hierarchy_tree(hierarchy,
  211. relation_matrix)
  212. return relation_matrix
  213. def _convert_hierarchy_tree(self,
  214. hierarchy_map: dict,
  215. relation_matrix: np.ndarray,
  216. parents: list = [],
  217. get_all_parents: bool = True) -> np.ndarray:
  218. """Get matrix of the corresponding relationship between the parent
  219. class and the child class.
  220. Args:
  221. hierarchy_map (dict): Including label name and corresponding
  222. subcategory. Keys of dicts are:
  223. - `LabeName` (str): Name of the label.
  224. - `Subcategory` (dict | list): Corresponding subcategory(ies).
  225. relation_matrix (ndarray): The matrix of the corresponding
  226. relationship between the parent class and the child class,
  227. of shape (class_num, class_num).
  228. parents (list): Corresponding parent class.
  229. get_all_parents (bool): Whether get all parent names.
  230. Default: True
  231. Returns:
  232. ndarray: The matrix of the corresponding relationship between
  233. the parent class and the child class, of shape
  234. (class_num, class_num).
  235. """
  236. if 'Subcategory' in hierarchy_map:
  237. for node in hierarchy_map['Subcategory']:
  238. if 'LabelName' in node:
  239. children_name = node['LabelName']
  240. children_index = self.label_id_mapping[children_name]
  241. children = [children_index]
  242. else:
  243. continue
  244. if len(parents) > 0:
  245. for parent_index in parents:
  246. if get_all_parents:
  247. children.append(parent_index)
  248. relation_matrix[children_index, parent_index] = 1
  249. relation_matrix = self._convert_hierarchy_tree(
  250. node, relation_matrix, parents=children)
  251. return relation_matrix
  252. def _join_prefix(self):
  253. """Join ``self.data_root`` with annotation path."""
  254. super()._join_prefix()
  255. if not is_abs(self.label_file) and self.label_file:
  256. self.label_file = osp.join(self.data_root, self.label_file)
  257. if not is_abs(self.meta_file) and self.meta_file:
  258. self.meta_file = osp.join(self.data_root, self.meta_file)
  259. if not is_abs(self.hierarchy_file) and self.hierarchy_file:
  260. self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
  261. if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
  262. self.image_level_ann_file = osp.join(self.data_root,
  263. self.image_level_ann_file)
  264. @DATASETS.register_module()
  265. class OpenImagesChallengeDataset(OpenImagesDataset):
  266. """Open Images Challenge dataset for detection.
  267. Args:
  268. ann_file (str): Open Images Challenge box annotation in txt format.
  269. """
  270. METAINFO: dict = dict(dataset_type='oid_challenge')
  271. def __init__(self, ann_file: str, **kwargs) -> None:
  272. if not ann_file.endswith('txt'):
  273. raise TypeError('The annotation file of Open Images Challenge '
  274. 'should be a txt file.')
  275. super().__init__(ann_file=ann_file, **kwargs)
  276. def load_data_list(self) -> List[dict]:
  277. """Load annotations from an annotation file named as ``self.ann_file``
  278. Returns:
  279. List[dict]: A list of annotation.
  280. """
  281. classes_names, label_id_mapping = self._parse_label_file(
  282. self.label_file)
  283. self._metainfo['classes'] = classes_names
  284. self.label_id_mapping = label_id_mapping
  285. if self.image_level_ann_file is not None:
  286. img_level_anns = self._parse_img_level_ann(
  287. self.image_level_ann_file)
  288. else:
  289. img_level_anns = None
  290. # OpenImagesMetric can get the relation matrix from the dataset meta
  291. relation_matrix = self._get_relation_matrix(self.hierarchy_file)
  292. self._metainfo['RELATION_MATRIX'] = relation_matrix
  293. data_list = []
  294. with get_local_path(
  295. self.ann_file, backend_args=self.backend_args) as local_path:
  296. with open(local_path, 'r') as f:
  297. lines = f.readlines()
  298. i = 0
  299. while i < len(lines):
  300. instances = []
  301. filename = lines[i].rstrip()
  302. i += 2
  303. img_gt_size = int(lines[i])
  304. i += 1
  305. for j in range(img_gt_size):
  306. sp = lines[i + j].split()
  307. instances.append(
  308. dict(
  309. bbox=[
  310. float(sp[1]),
  311. float(sp[2]),
  312. float(sp[3]),
  313. float(sp[4])
  314. ],
  315. bbox_label=int(sp[0]) - 1, # labels begin from 1
  316. ignore_flag=0,
  317. is_group_ofs=True if int(sp[5]) == 1 else False))
  318. i += img_gt_size
  319. data_list.append(
  320. dict(
  321. img_path=osp.join(self.data_prefix['img'], filename),
  322. instances=instances,
  323. ))
  324. # add image metas to data list
  325. img_metas = load(
  326. self.meta_file, file_format='pkl', backend_args=self.backend_args)
  327. assert len(img_metas) == len(data_list)
  328. for i, meta in enumerate(img_metas):
  329. img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
  330. assert img_id == osp.split(meta['filename'])[-1][:-4]
  331. h, w = meta['ori_shape'][:2]
  332. data_list[i]['height'] = h
  333. data_list[i]['width'] = w
  334. data_list[i]['img_id'] = img_id
  335. # denormalize bboxes
  336. for j in range(len(data_list[i]['instances'])):
  337. data_list[i]['instances'][j]['bbox'][0] *= w
  338. data_list[i]['instances'][j]['bbox'][2] *= w
  339. data_list[i]['instances'][j]['bbox'][1] *= h
  340. data_list[i]['instances'][j]['bbox'][3] *= h
  341. # add image-level annotation
  342. if img_level_anns is not None:
  343. img_labels = []
  344. confidences = []
  345. img_ann_list = img_level_anns.get(img_id, [])
  346. for ann in img_ann_list:
  347. img_labels.append(int(ann['image_level_label']))
  348. confidences.append(float(ann['confidence']))
  349. data_list[i]['image_level_labels'] = np.array(
  350. img_labels, dtype=np.int64)
  351. data_list[i]['confidences'] = np.array(
  352. confidences, dtype=np.float32)
  353. return data_list
  354. def _parse_label_file(self, label_file: str) -> tuple:
  355. """Get classes name and index mapping from cls-label-description file.
  356. Args:
  357. label_file (str): File path of the label description file that
  358. maps the classes names in MID format to their short
  359. descriptions.
  360. Returns:
  361. tuple: Class name of OpenImages.
  362. """
  363. label_list = []
  364. id_list = []
  365. index_mapping = {}
  366. with get_local_path(
  367. label_file, backend_args=self.backend_args) as local_path:
  368. with open(local_path, 'r') as f:
  369. reader = csv.reader(f)
  370. for line in reader:
  371. label_name = line[0]
  372. label_id = int(line[2])
  373. label_list.append(line[1])
  374. id_list.append(label_id)
  375. index_mapping[label_name] = label_id - 1
  376. indexes = np.argsort(id_list)
  377. classes_names = []
  378. for index in indexes:
  379. classes_names.append(label_list[index])
  380. return classes_names, index_mapping
  381. def _parse_img_level_ann(self, image_level_ann_file):
  382. """Parse image level annotations from csv style ann_file.
  383. Args:
  384. image_level_ann_file (str): CSV style image level annotation
  385. file path.
  386. Returns:
  387. defaultdict[list[dict]]: Annotations where item of the defaultdict
  388. indicates an image, each of which has (n) dicts.
  389. Keys of dicts are:
  390. - `image_level_label` (int): of shape 1.
  391. - `confidence` (float): of shape 1.
  392. """
  393. item_lists = defaultdict(list)
  394. with get_local_path(
  395. image_level_ann_file,
  396. backend_args=self.backend_args) as local_path:
  397. with open(local_path, 'r') as f:
  398. reader = csv.reader(f)
  399. i = -1
  400. for line in reader:
  401. i += 1
  402. if i == 0:
  403. continue
  404. else:
  405. img_id = line[0]
  406. label_id = line[1]
  407. assert label_id in self.label_id_mapping
  408. image_level_label = int(
  409. self.label_id_mapping[label_id])
  410. confidence = float(line[2])
  411. item_lists[img_id].append(
  412. dict(
  413. image_level_label=image_level_label,
  414. confidence=confidence))
  415. return item_lists
  416. def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
  417. """Get the matrix of class hierarchy from the hierarchy file.
  418. Args:
  419. hierarchy_file (str): File path to the hierarchy for classes.
  420. Returns:
  421. np.ndarray: The matrix of the corresponding
  422. relationship between the parent class and the child class,
  423. of shape (class_num, class_num).
  424. """
  425. with get_local_path(
  426. hierarchy_file, backend_args=self.backend_args) as local_path:
  427. class_label_tree = np.load(local_path, allow_pickle=True)
  428. return class_label_tree[1:, 1:]