panoptic_utils.py

# Copyright (c) OpenMMLab. All rights reserved.

# Copyright (c) 2018, Alexander Kirillov
# This file supports `backend_args` for `panopticapi`;
# the source code is copied from `panopticapi` and
# only the way to load the gt images is modified.
import multiprocessing
import os

import mmcv
import numpy as np
from mmengine.fileio import get

# A custom value to distinguish instance ID and category ID; it needs to
# be greater than the number of categories.
# For a pixel in the panoptic result map:
#   pan_id = ins_id * INSTANCE_OFFSET + cat_id
INSTANCE_OFFSET = 1000
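# A minimal worked example of this encoding (added for illustration,
# assuming the dataset has fewer than INSTANCE_OFFSET categories):
#   pan_id = 3 * INSTANCE_OFFSET + 17   # instance 3 of category 17 -> 3017
#   cat_id = pan_id % INSTANCE_OFFSET   # 3017 % 1000 = 17
#   ins_id = pan_id // INSTANCE_OFFSET  # 3017 // 1000 = 3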

try:
    from panopticapi.evaluation import OFFSET, VOID, PQStat
    from panopticapi.utils import rgb2id
except ImportError:
    PQStat = None
    rgb2id = None
    VOID = 0
    OFFSET = 256 * 256 * 256
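
# Note (added for illustration): `rgb2id` decodes a COCO panoptic PNG,
# where a segment id is stored across the three colour channels as
#   id = R + 256 * G + 256 ** 2 * B
# e.g. an RGB pixel (1, 2, 0) decodes to 1 + 256 * 2 = 513.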


def pq_compute_single_core(proc_id,
                           annotation_set,
                           gt_folder,
                           pred_folder,
                           categories,
                           backend_args=None,
                           print_log=False):
    """The single core function to evaluate the metric of Panoptic
    Segmentation.

    Same as the function with the same name in `panopticapi`. Only the
    function to load the images is changed to use the file backend.

    Args:
        proc_id (int): The id of the mini process.
        annotation_set (list): The matched annotations to be evaluated.
            Each element is a tuple with the format (gt_anns, pred_anns).
        gt_folder (str): The path of the ground truth images.
        pred_folder (str): The path of the prediction images.
        categories (dict): The categories of the dataset, with category
            ids as keys.
        backend_args (dict, optional): Arguments to instantiate the file
            backend for loading the gt images. If None, the local backend
            is used. Defaults to None.
        print_log (bool): Whether to print the log. Defaults to False.
    """
    if PQStat is None:
        raise RuntimeError(
            'panopticapi is not installed, please install it by: '
            'pip install git+https://github.com/cocodataset/'
            'panopticapi.git.')

    pq_stat = PQStat()

    idx = 0
    for gt_ann, pred_ann in annotation_set:
        if print_log and idx % 100 == 0:
            print('Core: {}, {} from {} images processed'.format(
                proc_id, idx, len(annotation_set)))
        idx += 1

        # The gt images can be on the local disk or `ceph`, so we use
        # the file backend here.
        img_bytes = get(
            os.path.join(gt_folder, gt_ann['file_name']),
            backend_args=backend_args)
        pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb')
        pan_gt = rgb2id(pan_gt)

        # The predictions can only be on the local disk for now.
        pan_pred = mmcv.imread(
            os.path.join(pred_folder, pred_ann['file_name']),
            flag='color',
            channel_order='rgb')
        pan_pred = rgb2id(pan_pred)

        gt_segms = {el['id']: el for el in gt_ann['segments_info']}
        pred_segms = {el['id']: el for el in pred_ann['segments_info']}

        # predicted segments area calculation + prediction sanity checks
        pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
        labels, labels_cnt = np.unique(pan_pred, return_counts=True)
        for label, label_cnt in zip(labels, labels_cnt):
            if label not in pred_segms:
                if label == VOID:
                    continue
                raise KeyError(
                    'In the image with ID {} segment with ID {} is '
                    'present in the PNG but not in the JSON.'.format(
                        gt_ann['image_id'], label))
            pred_segms[label]['area'] = label_cnt
            pred_labels_set.remove(label)
            if pred_segms[label]['category_id'] not in categories:
                raise KeyError(
                    'In the image with ID {} segment with ID {} has '
                    'unknown category_id {}.'.format(
                        gt_ann['image_id'], label,
                        pred_segms[label]['category_id']))
        if len(pred_labels_set) != 0:
            raise KeyError(
                'In the image with ID {} the following segment IDs {} '
                'are present in the JSON but not in the PNG.'.format(
                    gt_ann['image_id'], list(pred_labels_set)))

        # confusion matrix calculation
        pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(
            np.uint64)
        gt_pred_map = {}
        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
        for label, intersection in zip(labels, labels_cnt):
            gt_id = label // OFFSET
            pred_id = label % OFFSET
            gt_pred_map[(gt_id, pred_id)] = intersection
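        # Worked example of the packing above (added for illustration):
        # with OFFSET = 256 ** 3 = 16777216, a pixel with gt id 2 and
        # pred id 5 packs to 2 * 16777216 + 5 = 33554437; `// OFFSET` and
        # `% OFFSET` recover the pair, so one `np.unique` call counts
        # every (gt_id, pred_id) intersection area in a single pass.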

        # count all matched pairs
        gt_matched = set()
        pred_matched = set()
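        # Note (added): a (gt, pred) pair counts as a true positive only
        # when IoU > 0.5. Since segments within one image are disjoint,
        # at most one prediction can overlap a given gt segment with
        # IoU > 0.5, so this single greedy pass yields a unique matching.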
        for label_tuple, intersection in gt_pred_map.items():
            gt_label, pred_label = label_tuple
            if gt_label not in gt_segms:
                continue
            if pred_label not in pred_segms:
                continue
            if gt_segms[gt_label]['iscrowd'] == 1:
                continue
            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
                    'category_id']:
                continue

            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
                'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
            iou = intersection / union
            if iou > 0.5:
                pq_stat[gt_segms[gt_label]['category_id']].tp += 1
                pq_stat[gt_segms[gt_label]['category_id']].iou += iou
                gt_matched.add(gt_label)
                pred_matched.add(pred_label)

        # count false negatives
        crowd_labels_dict = {}
        for gt_label, gt_info in gt_segms.items():
            if gt_label in gt_matched:
                continue
            # crowd segments are ignored
            if gt_info['iscrowd'] == 1:
                crowd_labels_dict[gt_info['category_id']] = gt_label
                continue
            pq_stat[gt_info['category_id']].fn += 1

        # count false positives
        for pred_label, pred_info in pred_segms.items():
            if pred_label in pred_matched:
                continue
            # intersection of the segment with VOID
            intersection = gt_pred_map.get((VOID, pred_label), 0)
            # plus intersection with corresponding CROWD region if it exists
            if pred_info['category_id'] in crowd_labels_dict:
                intersection += gt_pred_map.get(
                    (crowd_labels_dict[pred_info['category_id']], pred_label),
                    0)
            # predicted segment is ignored if more than half of
            # the segment corresponds to VOID and CROWD regions
            if intersection / pred_info['area'] > 0.5:
                continue
            pq_stat[pred_info['category_id']].fp += 1

    if print_log:
        print('Core: {}, all {} images processed'.format(
            proc_id, len(annotation_set)))
    return pq_stat
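
# A minimal usage sketch for `pq_compute_single_core` (added for
# illustration; the folder paths and the pairing of `matched_list` are
# assumptions, not part of this file):
#
#   matched_list = [(gt_ann, pred_ann), ...]  # anns of the same image_id
#   categories = {cat['id']: cat for cat in gt_json['categories']}
#   pq_stat = pq_compute_single_core(
#       0, matched_list, 'data/gt_pan/', 'work_dirs/pred_pan/', categories)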


def pq_compute_multi_core(matched_annotations_list,
                          gt_folder,
                          pred_folder,
                          categories,
                          backend_args=None,
                          nproc=32):
    """Evaluate the metrics of Panoptic Segmentation with multiprocessing.

    Same as the function with the same name in `panopticapi`.

    Args:
        matched_annotations_list (list): The matched annotation list. Each
            element is a tuple of annotations of the same image with the
            format (gt_anns, pred_anns).
        gt_folder (str): The path of the ground truth images.
        pred_folder (str): The path of the prediction images.
        categories (dict): The categories of the dataset, with category
            ids as keys.
        backend_args (dict, optional): Arguments to instantiate the file
            backend for loading the gt images. If None, the local backend
            is used. Defaults to None.
        nproc (int): Number of processes for panoptic quality computing.
            Defaults to 32. When `nproc` exceeds the number of cpu cores,
            the number of cpu cores is used.
    """
    if PQStat is None:
        raise RuntimeError(
            'panopticapi is not installed, please install it by: '
            'pip install git+https://github.com/cocodataset/'
            'panopticapi.git.')

    cpu_num = min(nproc, multiprocessing.cpu_count())

    annotations_split = np.array_split(matched_annotations_list, cpu_num)
    print('Number of cores: {}, images per core: {}'.format(
        cpu_num, len(annotations_split[0])))

    workers = multiprocessing.Pool(processes=cpu_num)
    processes = []
    for proc_id, annotation_set in enumerate(annotations_split):
        p = workers.apply_async(pq_compute_single_core,
                                (proc_id, annotation_set, gt_folder,
                                 pred_folder, categories, backend_args))
        processes.append(p)

    # Close the process pool, otherwise it will lead to memory
    # leaking problems.
    workers.close()
    workers.join()

    pq_stat = PQStat()
    for p in processes:
        pq_stat += p.get()

    return pq_stat
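

# A minimal end-to-end sketch (added for illustration; the json/folder
# names are assumptions, not part of this file):
#
#   import json
#
#   with open('panoptic_val.json') as f:    # COCO panoptic gt json
#       gt_json = json.load(f)
#   with open('panoptic_pred.json') as f:   # predictions in the same format
#       pred_json = json.load(f)
#   pred_by_id = {a['image_id']: a for a in pred_json['annotations']}
#   matched = [(gt, pred_by_id[gt['image_id']])
#              for gt in gt_json['annotations']
#              if gt['image_id'] in pred_by_id]
#   categories = {cat['id']: cat for cat in gt_json['categories']}
#   pq_stat = pq_compute_multi_core(
#       matched, 'panoptic_val/', 'panoptic_pred/', categories, nproc=8)
#   # `PQStat.pq_average` from panopticapi aggregates tp/fp/fn/iou into
#   # PQ/SQ/RQ; `isthing=None` averages over all categories.
#   results, _ = pq_stat.pq_average(categories, isthing=None)
#   print(results['pq'])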