robustness_eval.py

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from argparse import ArgumentParser

import numpy as np
from mmengine.fileio import load
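
# This script summarises corruption-benchmark results for a detector:
#   P   - performance on clean data (severity level 0),
#   mPC - mean performance under corruption (severity levels 1-5; with
#         --aggregate benchmark only the first 15 distortions are used),
#   rPC - relative performance under corruption (mPC / P).
# Usage sketch; the result file name is illustrative and the file is assumed
# to be the nested dict saved by a robustness test run (e.g. mmdetection's
# test_robustness.py) in any format mmengine.fileio.load understands:
#
#   python robustness_eval.py robustness_results.pkl \
#       --dataset coco --task bbox --prints P mPC rPC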


def print_coco_results(results):
    """Print a 6-element result vector in the standard COCO AP layout.

    ``results`` is expected to hold, in order: mAP, mAP_50, mAP_75,
    mAP_s, mAP_m and mAP_l (see ``get_coco_style_results``).
    """

    def _print(result, ap=1, iouThr=None, areaRng='all', maxDets=100):
        titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
        typeStr = '(AP)' if ap == 1 else '(AR)'
        iouStr = '0.50:0.95' \
            if iouThr is None else f'{iouThr:0.2f}'
        iStr = f' {titleStr:<18} {typeStr} @[ IoU={iouStr:<9} | '
        iStr += f'area={areaRng:>6s} | maxDets={maxDets:>3d} ] = {result:0.3f}'
        print(iStr)

    stats = np.zeros((12, ))
    stats[0] = _print(results[0], 1)
    stats[1] = _print(results[1], 1, iouThr=.5)
    stats[2] = _print(results[2], 1, iouThr=.75)
    stats[3] = _print(results[3], 1, areaRng='small')
    stats[4] = _print(results[4], 1, areaRng='medium')
    stats[5] = _print(results[5], 1, areaRng='large')
    # TODO support recall metric
    '''
    stats[6] = _print(results[6], 0, maxDets=1)
    stats[7] = _print(results[7], 0, maxDets=10)
    stats[8] = _print(results[8], 0)
    stats[9] = _print(results[9], 0, areaRng='small')
    stats[10] = _print(results[10], 0, areaRng='medium')
    stats[11] = _print(results[11], 0, areaRng='large')
    '''


def get_coco_style_results(filename,
                           task='bbox',
                           metric=None,
                           prints='mPC',
                           aggregate='benchmark'):
    """Summarise COCO-style corruption results stored in ``filename``.

    ``metric`` may be a name or list of names from {'mAP', 'mAP_50',
    'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'}; by default all six are reported.
    Returns an array of shape ``(num_distortions, 6, num_metrics)`` holding
    one value per distortion, severity (0 = clean, 1-5 = corrupted) and
    metric, and prints P, mPC and rPC as requested via ``prints``.
    """
    assert aggregate in ['benchmark', 'all']

    if prints == 'all':
        prints = ['P', 'mPC', 'rPC']
    elif isinstance(prints, str):
        prints = [prints]
    for p in prints:
        assert p in ['P', 'mPC', 'rPC']

    if metric is None:
        metrics = [
            'mAP',
            'mAP_50',
            'mAP_75',
            'mAP_s',
            'mAP_m',
            'mAP_l',
        ]
    elif isinstance(metric, list):
        metrics = metric
    else:
        metrics = [metric]

    for metric_name in metrics:
        assert metric_name in [
            'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
        ]

    eval_output = load(filename)

    num_distortions = len(list(eval_output.keys()))
    results = np.zeros((num_distortions, 6, len(metrics)), dtype='float32')

    for corr_i, distortion in enumerate(eval_output):
        for severity in eval_output[distortion]:
            for metric_j, metric_name in enumerate(metrics):
                metric_dict = eval_output[distortion][severity]
                new_metric_dict = {}
                for k, v in metric_dict.items():
                    if '/' in k:
                        new_metric_dict[k.split('/')[-1]] = v
                mAP = new_metric_dict['_'.join((task, metric_name))]
                results[corr_i, severity, metric_j] = mAP

    P = results[0, 0, :]
    if aggregate == 'benchmark':
        # The first 15 distortions are the benchmark corruptions.
        mPC = np.mean(results[:15, 1:, :], axis=(0, 1))
    else:
        mPC = np.mean(results[:, 1:, :], axis=(0, 1))
    rPC = mPC / P

    print(f'\nmodel: {osp.basename(filename)}')
    if metric is None:
        if 'P' in prints:
            print(f'Performance on Clean Data [P] ({task})')
            print_coco_results(P)
        if 'mPC' in prints:
            print(f'Mean Performance under Corruption [mPC] ({task})')
            print_coco_results(mPC)
        if 'rPC' in prints:
            print(f'Relative Performance under Corruption [rPC] ({task})')
            print_coco_results(rPC)
    else:
        if 'P' in prints:
            print(f'Performance on Clean Data [P] ({task})')
            for metric_i, metric_name in enumerate(metrics):
                print(f'{metric_name:5} = {P[metric_i]:0.3f}')
        if 'mPC' in prints:
            print(f'Mean Performance under Corruption [mPC] ({task})')
            for metric_i, metric_name in enumerate(metrics):
                print(f'{metric_name:5} = {mPC[metric_i]:0.3f}')
        if 'rPC' in prints:
            print(f'Relative Performance under Corruption [rPC] ({task})')
            for metric_i, metric_name in enumerate(metrics):
                print(f'{metric_name:5} => {rPC[metric_i] * 100:0.1f} %')

    return results
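

# A minimal sketch (illustrative keys and numbers only) of the COCO-style
# input that ``get_coco_style_results`` above expects: a dict keyed by
# distortion name, then by integer severity 0-5, whose values map
# '<prefix>/<task>_<metric>' keys to scores, e.g.:
#
#   eval_output = {
#       'gaussian_noise': {
#           0: {'coco/bbox_mAP': 0.408, 'coco/bbox_mAP_50': 0.590},
#           1: {'coco/bbox_mAP': 0.312, 'coco/bbox_mAP_50': 0.481},
#       },
#   }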


def get_voc_style_results(filename, prints='mPC', aggregate='benchmark'):
    """Summarise VOC-style (per-class AP50) corruption results.

    Returns an array of shape ``(num_distortions, 6, 1)`` with the per-class
    APs averaged over the 20 VOC classes.
    """
    assert aggregate in ['benchmark', 'all']

    if prints == 'all':
        prints = ['P', 'mPC', 'rPC']
    elif isinstance(prints, str):
        prints = [prints]
    for p in prints:
        assert p in ['P', 'mPC', 'rPC']

    eval_output = load(filename)

    num_distortions = len(list(eval_output.keys()))
    results = np.zeros((num_distortions, 6, 20), dtype='float32')

    for i, distortion in enumerate(eval_output):
        for severity in eval_output[distortion]:
            mAP = [
                eval_output[distortion][severity][j]['ap']
                for j in range(len(eval_output[distortion][severity]))
            ]
            results[i, severity, :] = mAP

    P = results[0, 0, :]
    if aggregate == 'benchmark':
        mPC = np.mean(results[:15, 1:, :], axis=(0, 1))
    else:
        mPC = np.mean(results[:, 1:, :], axis=(0, 1))
    rPC = mPC / P

    print(f'\nmodel: {osp.basename(filename)}')
    if 'P' in prints:
        print(f'Performance on Clean Data [P] in AP50 = {np.mean(P):0.3f}')
    if 'mPC' in prints:
        print('Mean Performance under Corruption [mPC] in AP50 = '
              f'{np.mean(mPC):0.3f}')
    if 'rPC' in prints:
        print('Relative Performance under Corruption [rPC] in % = '
              f'{np.mean(rPC) * 100:0.1f}')

    return np.mean(results, axis=2, keepdims=True)
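

# A minimal sketch (illustrative numbers only) of the VOC-style input that
# ``get_voc_style_results`` above expects: per distortion and severity, one
# entry per class (20 for Pascal VOC), each carrying an 'ap' value, e.g.:
#
#   eval_output = {
#       'gaussian_noise': {
#           0: [{'ap': 0.801}, {'ap': 0.774}, ...],  # 20 per-class dicts
#           1: [{'ap': 0.652}, {'ap': 0.633}, ...],
#       },
#   }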


def get_results(filename,
                dataset='coco',
                task='bbox',
                metric=None,
                prints='mPC',
                aggregate='benchmark'):
    """Dispatch to the COCO- or VOC-style summary depending on ``dataset``."""
    assert dataset in ['coco', 'voc', 'cityscapes']

    if dataset in ['coco', 'cityscapes']:
        results = get_coco_style_results(
            filename,
            task=task,
            metric=metric,
            prints=prints,
            aggregate=aggregate)
    elif dataset == 'voc':
        if task != 'bbox':
            print('Only bbox analysis is supported for Pascal VOC')
            print('Will report bbox results\n')
        if metric not in [None, ['AP'], ['AP50']]:
            print('Only the AP50 metric is supported for Pascal VOC')
            print('Will report AP50 metric\n')
        results = get_voc_style_results(
            filename, prints=prints, aggregate=aggregate)

    return results


def get_distortions_from_file(filename):
    eval_output = load(filename)
    return get_distortions_from_results(eval_output)


def get_distortions_from_results(eval_output):
    distortions = []
    for i, distortion in enumerate(eval_output):
        distortions.append(distortion.replace('_', ' '))
    return distortions


def main():
    parser = ArgumentParser(description='Corruption Result Analysis')
    parser.add_argument('filename', help='result file path')
    parser.add_argument(
        '--dataset',
        type=str,
        choices=['coco', 'voc', 'cityscapes'],
        default='coco',
        help='dataset type')
    parser.add_argument(
        '--task',
        type=str,
        nargs='+',
        choices=['bbox', 'segm'],
        default=['bbox'],
        help='task to report')
    parser.add_argument(
        '--metric',
        nargs='+',
        choices=[
            None, 'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'AR1', 'AR10',
            'AR100', 'ARs', 'ARm', 'ARl'
        ],
        default=None,
        help='metric to report')
    parser.add_argument(
        '--prints',
        type=str,
        nargs='+',
        choices=['P', 'mPC', 'rPC'],
        default='mPC',
        help='corruption benchmark metric to print')
    parser.add_argument(
        '--aggregate',
        type=str,
        choices=['all', 'benchmark'],
        default='benchmark',
        help='aggregate all results or only those \
            for benchmark corruptions')
    args = parser.parse_args()

    for task in args.task:
        get_results(
            args.filename,
            dataset=args.dataset,
            task=task,
            metric=args.metric,
            prints=args.prints,
            aggregate=args.aggregate)


if __name__ == '__main__':
    main()
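
# The summaries can also be produced programmatically; a sketch with a
# hypothetical result file name:
#
#   from robustness_eval import get_results
#
#   results = get_results(
#       'robustness_results.pkl', dataset='coco', task='bbox',
#       prints=['P', 'mPC', 'rPC'], aggregate='benchmark')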