# coco_error_analysis.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from argparse import ArgumentParser
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
  10. def makeplot(rs, ps, outDir, class_name, iou_type):
  11. cs = np.vstack([
  12. np.ones((2, 3)),
  13. np.array([0.31, 0.51, 0.74]),
  14. np.array([0.75, 0.31, 0.30]),
  15. np.array([0.36, 0.90, 0.38]),
  16. np.array([0.50, 0.39, 0.64]),
  17. np.array([1, 0.6, 0]),
  18. ])
  19. areaNames = ['allarea', 'small', 'medium', 'large']
  20. types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
  21. for i in range(len(areaNames)):
  22. area_ps = ps[..., i, 0]
  23. figure_title = iou_type + '-' + class_name + '-' + areaNames[i]
  24. aps = [ps_.mean() for ps_ in area_ps]
  25. ps_curve = [
  26. ps_.mean(axis=1) if ps_.ndim > 1 else ps_ for ps_ in area_ps
  27. ]
  28. ps_curve.insert(0, np.zeros(ps_curve[0].shape))
  29. fig = plt.figure()
  30. ax = plt.subplot(111)
  31. for k in range(len(types)):
  32. ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
  33. ax.fill_between(
  34. rs,
  35. ps_curve[k],
  36. ps_curve[k + 1],
  37. color=cs[k],
  38. label=str(f'[{aps[k]:.3f}]' + types[k]),
  39. )
  40. plt.xlabel('recall')
  41. plt.ylabel('precision')
  42. plt.xlim(0, 1.0)
  43. plt.ylim(0, 1.0)
  44. plt.title(figure_title)
  45. plt.legend()
  46. # plt.show()
  47. fig.savefig(outDir + f'/{figure_title}.png')
  48. plt.close(fig)
  49. def autolabel(ax, rects):
  50. """Attach a text label above each bar in *rects*, displaying its height."""
  51. for rect in rects:
  52. height = rect.get_height()
  53. if height > 0 and height <= 1: # for percent values
  54. text_label = '{:2.0f}'.format(height * 100)
  55. else:
  56. text_label = '{:2.0f}'.format(height)
  57. ax.annotate(
  58. text_label,
  59. xy=(rect.get_x() + rect.get_width() / 2, height),
  60. xytext=(0, 3), # 3 points vertical offset
  61. textcoords='offset points',
  62. ha='center',
  63. va='bottom',
  64. fontsize='x-small',
  65. )
  66. def makebarplot(rs, ps, outDir, class_name, iou_type):
  67. areaNames = ['allarea', 'small', 'medium', 'large']
  68. types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
  69. fig, ax = plt.subplots()
  70. x = np.arange(len(areaNames)) # the areaNames locations
  71. width = 0.60 # the width of the bars
  72. rects_list = []
  73. figure_title = iou_type + '-' + class_name + '-' + 'ap bar plot'
  74. for i in range(len(types) - 1):
  75. type_ps = ps[i, ..., 0]
  76. aps = [ps_.mean() for ps_ in type_ps.T]
  77. rects_list.append(
  78. ax.bar(
  79. x - width / 2 + (i + 1) * width / len(types),
  80. aps,
  81. width / len(types),
  82. label=types[i],
  83. ))
  84. # Add some text for labels, title and custom x-axis tick labels, etc.
  85. ax.set_ylabel('Mean Average Precision (mAP)')
  86. ax.set_title(figure_title)
  87. ax.set_xticks(x)
  88. ax.set_xticklabels(areaNames)
  89. ax.legend()
  90. # Add score texts over bars
  91. for rects in rects_list:
  92. autolabel(ax, rects)
  93. # Save plot
  94. fig.savefig(outDir + f'/{figure_title}.png')
  95. plt.close(fig)
  96. def get_gt_area_group_numbers(cocoEval):
  97. areaRng = cocoEval.params.areaRng
  98. areaRngStr = [str(aRng) for aRng in areaRng]
  99. areaRngLbl = cocoEval.params.areaRngLbl
  100. areaRngStr2areaRngLbl = dict(zip(areaRngStr, areaRngLbl))
  101. areaRngLbl2Number = dict.fromkeys(areaRngLbl, 0)
  102. for evalImg in cocoEval.evalImgs:
  103. if evalImg:
  104. for gtIgnore in evalImg['gtIgnore']:
  105. if not gtIgnore:
  106. aRngLbl = areaRngStr2areaRngLbl[str(evalImg['aRng'])]
  107. areaRngLbl2Number[aRngLbl] += 1
  108. return areaRngLbl2Number
  109. def make_gt_area_group_numbers_plot(cocoEval, outDir, verbose=True):
  110. areaRngLbl2Number = get_gt_area_group_numbers(cocoEval)
  111. areaRngLbl = areaRngLbl2Number.keys()
  112. if verbose:
  113. print('number of annotations per area group:', areaRngLbl2Number)
  114. # Init figure
  115. fig, ax = plt.subplots()
  116. x = np.arange(len(areaRngLbl)) # the areaNames locations
  117. width = 0.60 # the width of the bars
  118. figure_title = 'number of annotations per area group'
  119. rects = ax.bar(x, areaRngLbl2Number.values(), width)
  120. # Add some text for labels, title and custom x-axis tick labels, etc.
  121. ax.set_ylabel('Number of annotations')
  122. ax.set_title(figure_title)
  123. ax.set_xticks(x)
  124. ax.set_xticklabels(areaRngLbl)
  125. # Add score texts over bars
  126. autolabel(ax, rects)
  127. # Save plot
  128. fig.tight_layout()
  129. fig.savefig(outDir + f'/{figure_title}.png')
  130. plt.close(fig)
  131. def make_gt_area_histogram_plot(cocoEval, outDir):
  132. n_bins = 100
  133. areas = [ann['area'] for ann in cocoEval.cocoGt.anns.values()]
  134. # init figure
  135. figure_title = 'gt annotation areas histogram plot'
  136. fig, ax = plt.subplots()
  137. # Set the number of bins
  138. ax.hist(np.sqrt(areas), bins=n_bins)
  139. # Add some text for labels, title and custom x-axis tick labels, etc.
  140. ax.set_xlabel('Squareroot Area')
  141. ax.set_ylabel('Number of annotations')
  142. ax.set_title(figure_title)
  143. # Save plot
  144. fig.tight_layout()
  145. fig.savefig(outDir + f'/{figure_title}.png')
  146. plt.close(fig)
  147. def analyze_individual_category(k,
  148. cocoDt,
  149. cocoGt,
  150. catId,
  151. iou_type,
  152. areas=None):
  153. nm = cocoGt.loadCats(catId)[0]
  154. print(f'--------------analyzing {k + 1}-{nm["name"]}---------------')
  155. ps_ = {}
  156. dt = copy.deepcopy(cocoDt)
  157. nm = cocoGt.loadCats(catId)[0]
  158. imgIds = cocoGt.getImgIds()
  159. dt_anns = dt.dataset['annotations']
  160. select_dt_anns = []
  161. for ann in dt_anns:
  162. if ann['category_id'] == catId:
  163. select_dt_anns.append(ann)
  164. dt.dataset['annotations'] = select_dt_anns
  165. dt.createIndex()
  166. # compute precision but ignore superclass confusion
  167. gt = copy.deepcopy(cocoGt)
  168. child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
  169. for idx, ann in enumerate(gt.dataset['annotations']):
  170. if ann['category_id'] in child_catIds and ann['category_id'] != catId:
  171. gt.dataset['annotations'][idx]['ignore'] = 1
  172. gt.dataset['annotations'][idx]['iscrowd'] = 1
  173. gt.dataset['annotations'][idx]['category_id'] = catId
  174. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  175. cocoEval.params.imgIds = imgIds
  176. cocoEval.params.maxDets = [100]
  177. cocoEval.params.iouThrs = [0.1]
  178. cocoEval.params.useCats = 1
  179. if areas:
  180. cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
  181. [areas[0], areas[1]], [areas[1], areas[2]]]
  182. cocoEval.evaluate()
  183. cocoEval.accumulate()
  184. ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
  185. ps_['ps_supercategory'] = ps_supercategory
  186. # compute precision but ignore any class confusion
  187. gt = copy.deepcopy(cocoGt)
  188. for idx, ann in enumerate(gt.dataset['annotations']):
  189. if ann['category_id'] != catId:
  190. gt.dataset['annotations'][idx]['ignore'] = 1
  191. gt.dataset['annotations'][idx]['iscrowd'] = 1
  192. gt.dataset['annotations'][idx]['category_id'] = catId
  193. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  194. cocoEval.params.imgIds = imgIds
  195. cocoEval.params.maxDets = [100]
  196. cocoEval.params.iouThrs = [0.1]
  197. cocoEval.params.useCats = 1
  198. if areas:
  199. cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
  200. [areas[0], areas[1]], [areas[1], areas[2]]]
  201. cocoEval.evaluate()
  202. cocoEval.accumulate()
  203. ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
  204. ps_['ps_allcategory'] = ps_allcategory
  205. return k, ps_
  206. def analyze_results(res_file,
  207. ann_file,
  208. res_types,
  209. out_dir,
  210. extraplots=None,
  211. areas=None):
  212. for res_type in res_types:
  213. assert res_type in ['bbox', 'segm']
  214. if areas:
  215. assert len(areas) == 3, '3 integers should be specified as areas, \
  216. representing 3 area regions'
  217. directory = os.path.dirname(out_dir + '/')
  218. if not os.path.exists(directory):
  219. print(f'-------------create {out_dir}-----------------')
  220. os.makedirs(directory)
  221. cocoGt = COCO(ann_file)
  222. cocoDt = cocoGt.loadRes(res_file)
  223. imgIds = cocoGt.getImgIds()
  224. for res_type in res_types:
  225. res_out_dir = out_dir + '/' + res_type + '/'
  226. res_directory = os.path.dirname(res_out_dir)
  227. if not os.path.exists(res_directory):
  228. print(f'-------------create {res_out_dir}-----------------')
  229. os.makedirs(res_directory)
  230. iou_type = res_type
  231. cocoEval = COCOeval(
  232. copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
  233. cocoEval.params.imgIds = imgIds
  234. cocoEval.params.iouThrs = [0.75, 0.5, 0.1]
  235. cocoEval.params.maxDets = [100]
  236. if areas:
  237. cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
  238. [areas[0], areas[1]],
  239. [areas[1], areas[2]]]
  240. cocoEval.evaluate()
  241. cocoEval.accumulate()
  242. ps = cocoEval.eval['precision']
  243. ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
  244. catIds = cocoGt.getCatIds()
  245. recThrs = cocoEval.params.recThrs
  246. with Pool(processes=48) as pool:
  247. args = [(k, cocoDt, cocoGt, catId, iou_type, areas)
  248. for k, catId in enumerate(catIds)]
  249. analyze_results = pool.starmap(analyze_individual_category, args)
  250. for k, catId in enumerate(catIds):
  251. nm = cocoGt.loadCats(catId)[0]
  252. print(f'--------------saving {k + 1}-{nm["name"]}---------------')
  253. analyze_result = analyze_results[k]
  254. assert k == analyze_result[0]
  255. ps_supercategory = analyze_result[1]['ps_supercategory']
  256. ps_allcategory = analyze_result[1]['ps_allcategory']
  257. # compute precision but ignore superclass confusion
  258. ps[3, :, k, :, :] = ps_supercategory
  259. # compute precision but ignore any class confusion
  260. ps[4, :, k, :, :] = ps_allcategory
  261. # fill in background and false negative errors and plot
  262. ps[ps == -1] = 0
  263. ps[5, :, k, :, :] = ps[4, :, k, :, :] > 0
  264. ps[6, :, k, :, :] = 1.0
  265. makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
  266. if extraplots:
  267. makebarplot(recThrs, ps[:, :, k], res_out_dir, nm['name'],
  268. iou_type)
  269. makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
  270. if extraplots:
  271. makebarplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
  272. make_gt_area_group_numbers_plot(
  273. cocoEval=cocoEval, outDir=res_out_dir, verbose=True)
  274. make_gt_area_histogram_plot(cocoEval=cocoEval, outDir=res_out_dir)
  275. def main():
  276. parser = ArgumentParser(description='COCO Error Analysis Tool')
  277. parser.add_argument('result', help='result file (json format) path')
  278. parser.add_argument('out_dir', help='dir to save analyze result images')
  279. parser.add_argument(
  280. '--ann',
  281. default='data/coco/annotations/instances_val2017.json',
  282. help='annotation file path')
  283. parser.add_argument(
  284. '--types', type=str, nargs='+', default=['bbox'], help='result types')
  285. parser.add_argument(
  286. '--extraplots',
  287. action='store_true',
  288. help='export extra bar/stat plots')
  289. parser.add_argument(
  290. '--areas',
  291. type=int,
  292. nargs='+',
  293. default=[1024, 9216, 10000000000],
  294. help='area regions')
  295. args = parser.parse_args()
  296. analyze_results(
  297. args.result,
  298. args.ann,
  299. args.types,
  300. out_dir=args.out_dir,
  301. extraplots=args.extraplots,
  302. areas=args.areas)
  303. if __name__ == '__main__':
  304. main()