test_robustness.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.dist import get_dist_info
from mmengine.evaluator import DumpResults
from mmengine.fileio import dump
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.registry import RUNNERS
from tools.analysis_tools.robustness_eval import get_results
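
# Evaluate a trained detector under synthetic image corruptions: every
# requested (corruption, severity) pair is tested in turn and the aggregated
# metrics are dumped next to the file passed via --out.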


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--out',
        type=str,
        help='dump predictions to a pickle file for offline evaluation')
    parser.add_argument(
        '--corruptions',
        type=str,
        nargs='+',
        # a list default matches the list produced by nargs='+'
        default=['benchmark'],
        choices=[
            'all', 'benchmark', 'noise', 'blur', 'weather', 'digital',
            'holdout', 'None', 'gaussian_noise', 'shot_noise', 'impulse_noise',
            'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow',
            'frost', 'fog', 'brightness', 'contrast', 'elastic_transform',
            'pixelate', 'jpeg_compression', 'speckle_noise', 'gaussian_blur',
            'spatter', 'saturate'
        ],
        help='corruption types or group names to test')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--severities',
        type=int,
        nargs='+',
        default=[0, 1, 2, 3, 4, 5],
        help='corruption severity levels')
    parser.add_argument(
        '--summaries',
        # `type=bool` would turn every non-empty string (even 'False') into
        # True, so use a flag instead
        action='store_true',
        help='print summaries for every corruption and severity')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=2,
        help='display interval between shown images (in seconds)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--final-prints',
        type=str,
        nargs='+',
        choices=['P', 'mPC', 'rPC'],
        default='mPC',
        help='corruption benchmark metric to print at the end')
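    # P is performance on clean data (severity 0), mPC the mean performance
    # across corruptions and severities, and rPC the ratio mPC / P, following
    # the "Benchmarking Robustness in Object Detection" protocol
    # (Michaelis et al., 2019).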
    parser.add_argument(
        '--final-prints-aggregate',
        type=str,
        choices=['all', 'benchmark'],
        default='benchmark',
        help='aggregate all results or only those for benchmark corruptions')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
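

# Example invocation (config and checkpoint paths are illustrative):
#   python tools/analysis_tools/test_robustness.py \
#       configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py \
#       checkpoints/faster_rcnn_r50_fpn_1x_coco.pth \
#       --out results/robustness.pkl --corruptions benchmark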
def main():
    args = parse_args()

    assert args.out or args.show or args.show_dir, \
        ('Please specify at least one operation (save or show the results) '
         'with the argument "--out", "--show" or "--show-dir"')

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
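
    # Loading the full checkpoint via `load_from` below makes the backbone's
    # own pretrained-weight init redundant, so it is disabled here; the test
    # dataset is switched to test mode.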
    cfg.model.backbone.init_cfg.type = None
    cfg.test_dataloader.dataset.test_mode = True

    cfg.load_from = args.checkpoint
    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    # build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    # add `DumpResults` dummy metric
    if args.out is not None:
        assert args.out.endswith(('.pkl', '.pickle')), \
            'The dump file must be a pkl file.'
        runner.test_evaluator.metrics.append(
            DumpResults(out_file_path=args.out))
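
    # Resolve group names to concrete corruption lists. The groups follow the
    # imagecorruptions taxonomy: 15 benchmark corruptions (noise, blur,
    # weather, digital) plus 4 hold-out corruptions.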
    if 'all' in args.corruptions:
        corruptions = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
            'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
            'brightness', 'contrast', 'elastic_transform', 'pixelate',
            'jpeg_compression', 'speckle_noise', 'gaussian_blur', 'spatter',
            'saturate'
        ]
    elif 'benchmark' in args.corruptions:
        corruptions = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
            'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
            'brightness', 'contrast', 'elastic_transform', 'pixelate',
            'jpeg_compression'
        ]
    elif 'noise' in args.corruptions:
        corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise']
    elif 'blur' in args.corruptions:
        corruptions = [
            'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'
        ]
    elif 'weather' in args.corruptions:
        corruptions = ['snow', 'frost', 'fog', 'brightness']
    elif 'digital' in args.corruptions:
        corruptions = [
            'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
        ]
    elif 'holdout' in args.corruptions:
        corruptions = ['speckle_noise', 'gaussian_blur', 'spatter', 'saturate']
    elif 'None' in args.corruptions:
        corruptions = ['None']
        args.severities = [0]
    else:
        corruptions = args.corruptions
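
    # aggregated_results maps corruption -> severity -> metric dict, e.g.
    # {'gaussian_noise': {0: {...}, 1: {...}, ...}, ...}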
    aggregated_results = {}
    for corr_i, corruption in enumerate(corruptions):
        aggregated_results[corruption] = {}
        for sev_i, corruption_severity in enumerate(args.severities):
            # evaluate severity 0 (= no corruption) only once
            if corr_i > 0 and corruption_severity == 0:
                aggregated_results[corruption][0] = \
                    aggregated_results[corruptions[0]][0]
                continue

            test_loader_cfg = copy.deepcopy(cfg.test_dataloader)
            # assign corruption and severity
            if corruption_severity > 0:
                corruption_trans = dict(
                    type='Corrupt',
                    corruption=corruption,
                    severity=corruption_severity)
                # TODO: hard coded "1", we assume that the first step is
                # loading images, which needs to be fixed in the future
                test_loader_cfg.dataset.pipeline.insert(1, corruption_trans)
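                # after the insert the pipeline starts roughly like
                # [LoadImageFromFile, Corrupt, ...]; `Corrupt` applies the
                # corruption via the imagecorruptions package, which must be
                # installed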

            test_loader = runner.build_dataloader(test_loader_cfg)
            # swap the dataloader on the existing test loop so the same
            # runner and model are reused for every corruption setting
            runner.test_loop.dataloader = test_loader

            # set random seeds
            if args.seed is not None:
                runner.set_randomness(args.seed)

            # print info
            print(f'\nTesting {corruption} at severity {corruption_severity}')

            eval_results = runner.test()

            # record results even when --out is not given, so the severity-0
            # entry shared across corruptions above is always present
            aggregated_results[corruption][
                corruption_severity] = eval_results
            if args.out:
                eval_results_filename = (
                    osp.splitext(args.out)[0] + '_results' +
                    osp.splitext(args.out)[1])
                # dump inside the loop so partial results survive if a later
                # run is interrupted
                dump(aggregated_results, eval_results_filename)

    rank, _ = get_dist_info()
    # the final summary reads the dumped results file, so it is only produced
    # on the main process and when --out was given
    if rank == 0 and args.out:
        eval_results_filename = (
            osp.splitext(args.out)[0] + '_results' + osp.splitext(args.out)[1])

        # print final results
        print('\nAggregated results:')
        prints = args.final_prints
        aggregate = args.final_prints_aggregate

        if cfg.dataset_type == 'VOCDataset':
            get_results(
                eval_results_filename,
                dataset='voc',
                prints=prints,
                aggregate=aggregate)
        else:
            get_results(
                eval_results_filename,
                dataset='coco',
                prints=prints,
                aggregate=aggregate)


if __name__ == '__main__':
    main()