cityscapes_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Copyright (c) https://github.com/mcordts/cityscapesScripts
  3. # A wrapper of `cityscapesscripts` which supports loading groundtruth
  4. # image from `backend_args`.
  5. import json
  6. import os
  7. import sys
  8. from pathlib import Path
  9. from typing import Optional, Union
  10. import mmcv
  11. import numpy as np
  12. from mmengine.fileio import get
  13. try:
  14. import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa: E501
  15. from cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling import \
  16. CArgs # noqa: E501
  17. from cityscapesscripts.evaluation.instance import Instance
  18. from cityscapesscripts.helpers.csHelpers import (id2label, labels,
  19. writeDict2JSON)
  20. HAS_CITYSCAPESAPI = True
  21. except ImportError:
  22. CArgs = object
  23. HAS_CITYSCAPESAPI = False
  24. def evaluateImgLists(prediction_list: list,
  25. groundtruth_list: list,
  26. args: CArgs,
  27. backend_args: Optional[dict] = None,
  28. dump_matches: bool = False) -> dict:
  29. """A wrapper of obj:``cityscapesscripts.evaluation.
  30. evalInstanceLevelSemanticLabeling.evaluateImgLists``. Support loading
  31. groundtruth image from file backend.
  32. Args:
  33. prediction_list (list): A list of prediction txt file.
  34. groundtruth_list (list): A list of groundtruth image file.
  35. args (CArgs): A global object setting in
  36. obj:``cityscapesscripts.evaluation.
  37. evalInstanceLevelSemanticLabeling``
  38. backend_args (dict, optional): Arguments to instantiate the
  39. preifx of uri corresponding backend. Defaults to None.
  40. dump_matches (bool): whether dump matches.json. Defaults to False.
  41. Returns:
  42. dict: The computed metric.
  43. """
  44. if not HAS_CITYSCAPESAPI:
  45. raise RuntimeError('Failed to import `cityscapesscripts`.'
  46. 'Please try to install official '
  47. 'cityscapesscripts by '
  48. '"pip install cityscapesscripts"')
  49. # determine labels of interest
  50. CSEval.setInstanceLabels(args)
  51. # get dictionary of all ground truth instances
  52. gt_instances = getGtInstances(
  53. groundtruth_list, args, backend_args=backend_args)
  54. # match predictions and ground truth
  55. matches = matchGtWithPreds(prediction_list, groundtruth_list, gt_instances,
  56. args, backend_args)
  57. if dump_matches:
  58. CSEval.writeDict2JSON(matches, 'matches.json')
  59. # evaluate matches
  60. apScores = CSEval.evaluateMatches(matches, args)
  61. # averages
  62. avgDict = CSEval.computeAverages(apScores, args)
  63. # result dict
  64. resDict = CSEval.prepareJSONDataForResults(avgDict, apScores, args)
  65. if args.JSONOutput:
  66. # create output folder if necessary
  67. path = os.path.dirname(args.exportFile)
  68. CSEval.ensurePath(path)
  69. # Write APs to JSON
  70. CSEval.writeDict2JSON(resDict, args.exportFile)
  71. CSEval.printResults(avgDict, args)
  72. return resDict
  73. def matchGtWithPreds(prediction_list: list,
  74. groundtruth_list: list,
  75. gt_instances: dict,
  76. args: CArgs,
  77. backend_args=None):
  78. """A wrapper of obj:``cityscapesscripts.evaluation.
  79. evalInstanceLevelSemanticLabeling.matchGtWithPreds``. Support loading
  80. groundtruth image from file backend.
  81. Args:
  82. prediction_list (list): A list of prediction txt file.
  83. groundtruth_list (list): A list of groundtruth image file.
  84. gt_instances (dict): Groundtruth dict.
  85. args (CArgs): A global object setting in
  86. obj:``cityscapesscripts.evaluation.
  87. evalInstanceLevelSemanticLabeling``
  88. backend_args (dict, optional): Arguments to instantiate the
  89. preifx of uri corresponding backend. Defaults to None.
  90. Returns:
  91. dict: The processed prediction and groundtruth result.
  92. """
  93. if not HAS_CITYSCAPESAPI:
  94. raise RuntimeError('Failed to import `cityscapesscripts`.'
  95. 'Please try to install official '
  96. 'cityscapesscripts by '
  97. '"pip install cityscapesscripts"')
  98. matches: dict = dict()
  99. if not args.quiet:
  100. print(f'Matching {len(prediction_list)} pairs of images...')
  101. count = 0
  102. for (pred, gt) in zip(prediction_list, groundtruth_list):
  103. # Read input files
  104. gt_image = readGTImage(gt, backend_args)
  105. pred_info = readPredInfo(pred)
  106. # Get and filter ground truth instances
  107. unfiltered_instances = gt_instances[gt]
  108. cur_gt_instances_orig = CSEval.filterGtInstances(
  109. unfiltered_instances, args)
  110. # Try to assign all predictions
  111. (cur_gt_instances,
  112. cur_pred_instances) = CSEval.assignGt2Preds(cur_gt_instances_orig,
  113. gt_image, pred_info, args)
  114. # append to global dict
  115. matches[gt] = {}
  116. matches[gt]['groundTruth'] = cur_gt_instances
  117. matches[gt]['prediction'] = cur_pred_instances
  118. count += 1
  119. if not args.quiet:
  120. print(f'\rImages Processed: {count}', end=' ')
  121. sys.stdout.flush()
  122. if not args.quiet:
  123. print('')
  124. return matches
  125. def readGTImage(image_file: Union[str, Path],
  126. backend_args: Optional[dict] = None) -> np.ndarray:
  127. """Read an image from path.
  128. Same as obj:``cityscapesscripts.evaluation.
  129. evalInstanceLevelSemanticLabeling.readGTImage``, but support loading
  130. groundtruth image from file backend.
  131. Args:
  132. image_file (str or Path): Either a str or pathlib.Path.
  133. backend_args (dict, optional): Instantiates the corresponding file
  134. backend. It may contain `backend` key to specify the file
  135. backend. If it contains, the file backend corresponding to this
  136. value will be used and initialized with the remaining values,
  137. otherwise the corresponding file backend will be selected
  138. based on the prefix of the file path. Defaults to None.
  139. Returns:
  140. np.ndarray: The groundtruth image.
  141. """
  142. img_bytes = get(image_file, backend_args=backend_args)
  143. img = mmcv.imfrombytes(img_bytes, flag='unchanged', backend='pillow')
  144. return img
  145. def readPredInfo(prediction_file: str) -> dict:
  146. """A wrapper of obj:``cityscapesscripts.evaluation.
  147. evalInstanceLevelSemanticLabeling.readPredInfo``.
  148. Args:
  149. prediction_file (str): The prediction txt file.
  150. Returns:
  151. dict: The processed prediction results.
  152. """
  153. if not HAS_CITYSCAPESAPI:
  154. raise RuntimeError('Failed to import `cityscapesscripts`.'
  155. 'Please try to install official '
  156. 'cityscapesscripts by '
  157. '"pip install cityscapesscripts"')
  158. printError = CSEval.printError
  159. predInfo = {}
  160. if (not os.path.isfile(prediction_file)):
  161. printError(f"Infofile '{prediction_file}' "
  162. 'for the predictions not found.')
  163. with open(prediction_file) as f:
  164. for line in f:
  165. splittedLine = line.split(' ')
  166. if len(splittedLine) != 3:
  167. printError('Invalid prediction file. Expected content: '
  168. 'relPathPrediction1 labelIDPrediction1 '
  169. 'confidencePrediction1')
  170. if os.path.isabs(splittedLine[0]):
  171. printError('Invalid prediction file. First entry in each '
  172. 'line must be a relative path.')
  173. filename = os.path.join(
  174. os.path.dirname(prediction_file), splittedLine[0])
  175. imageInfo = {}
  176. imageInfo['labelID'] = int(float(splittedLine[1]))
  177. imageInfo['conf'] = float(splittedLine[2]) # type: ignore
  178. predInfo[filename] = imageInfo
  179. return predInfo
  180. def getGtInstances(groundtruth_list: list,
  181. args: CArgs,
  182. backend_args: Optional[dict] = None) -> dict:
  183. """A wrapper of obj:``cityscapesscripts.evaluation.
  184. evalInstanceLevelSemanticLabeling.getGtInstances``. Support loading
  185. groundtruth image from file backend.
  186. Args:
  187. groundtruth_list (list): A list of groundtruth image file.
  188. args (CArgs): A global object setting in
  189. obj:``cityscapesscripts.evaluation.
  190. evalInstanceLevelSemanticLabeling``
  191. backend_args (dict, optional): Arguments to instantiate the
  192. preifx of uri corresponding backend. Defaults to None.
  193. Returns:
  194. dict: The computed metric.
  195. """
  196. if not HAS_CITYSCAPESAPI:
  197. raise RuntimeError('Failed to import `cityscapesscripts`.'
  198. 'Please try to install official '
  199. 'cityscapesscripts by '
  200. '"pip install cityscapesscripts"')
  201. # if there is a global statistics json, then load it
  202. if (os.path.isfile(args.gtInstancesFile)):
  203. if not args.quiet:
  204. print('Loading ground truth instances from JSON.')
  205. with open(args.gtInstancesFile) as json_file:
  206. gt_instances = json.load(json_file)
  207. # otherwise create it
  208. else:
  209. if (not args.quiet):
  210. print('Creating ground truth instances from png files.')
  211. gt_instances = instances2dict(
  212. groundtruth_list, args, backend_args=backend_args)
  213. writeDict2JSON(gt_instances, args.gtInstancesFile)
  214. return gt_instances
  215. def instances2dict(image_list: list,
  216. args: CArgs,
  217. backend_args: Optional[dict] = None) -> dict:
  218. """A wrapper of obj:``cityscapesscripts.evaluation.
  219. evalInstanceLevelSemanticLabeling.instances2dict``. Support loading
  220. groundtruth image from file backend.
  221. Args:
  222. image_list (list): A list of image file.
  223. args (CArgs): A global object setting in
  224. obj:``cityscapesscripts.evaluation.
  225. evalInstanceLevelSemanticLabeling``
  226. backend_args (dict, optional): Arguments to instantiate the
  227. preifx of uri corresponding backend. Defaults to None.
  228. Returns:
  229. dict: The processed groundtruth results.
  230. """
  231. if not HAS_CITYSCAPESAPI:
  232. raise RuntimeError('Failed to import `cityscapesscripts`.'
  233. 'Please try to install official '
  234. 'cityscapesscripts by '
  235. '"pip install cityscapesscripts"')
  236. imgCount = 0
  237. instanceDict = {}
  238. if not isinstance(image_list, list):
  239. image_list = [image_list]
  240. if not args.quiet:
  241. print(f'Processing {len(image_list)} images...')
  242. for image_name in image_list:
  243. # Load image
  244. img_bytes = get(image_name, backend_args=backend_args)
  245. imgNp = mmcv.imfrombytes(img_bytes, flag='unchanged', backend='pillow')
  246. # Initialize label categories
  247. instances: dict = {}
  248. for label in labels:
  249. instances[label.name] = []
  250. # Loop through all instance ids in instance image
  251. for instanceId in np.unique(imgNp):
  252. instanceObj = Instance(imgNp, instanceId)
  253. instances[id2label[instanceObj.labelID].name].append(
  254. instanceObj.toDict())
  255. instanceDict[image_name] = instances
  256. imgCount += 1
  257. if not args.quiet:
  258. print(f'\rImages Processed: {imgCount}', end=' ')
  259. sys.stdout.flush()
  260. return instanceDict