gather_models.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import json
import os.path as osp
import shutil
import subprocess
from collections import OrderedDict

import torch
import yaml
from mmengine.config import Config
from mmengine.fileio import dump
from mmengine.utils import digit_version, mkdir_or_exist, scandir


def ordered_yaml_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
    # dump mappings to YAML while preserving the key order of OrderedDict

    class OrderedDumper(Dumper):
        pass

    def _dict_representer(dumper, data):
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items())

    OrderedDumper.add_representer(OrderedDict, _dict_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # remove ema state_dict
    for key in list(checkpoint['state_dict']):
        if key.startswith('ema_'):
            checkpoint['state_dict'].pop(key)

    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    # save with the legacy (non-zipfile) format for wider compatibility;
    # digit_version gives a numeric, not lexicographic, version comparison
    if digit_version(torch.__version__) >= digit_version('1.6'):
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
    # append the first 8 characters of the SHA-256 hash to the file name
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    final_file = osp.splitext(out_file)[0] + '-{}.pth'.format(sha[:8])
    shutil.move(out_file, final_file)
    return final_file


def is_by_epoch(config):
    cfg = Config.fromfile('./configs/' + config)
    return cfg.runner.type == 'EpochBasedRunner'


def get_final_epoch_or_iter(config):
    cfg = Config.fromfile('./configs/' + config)
    if cfg.runner.type == 'EpochBasedRunner':
        return cfg.runner.max_epochs
    else:
        return cfg.runner.max_iters


def get_best_epoch_or_iter(exp_dir):
    # find the latest best_*.pth checkpoint and parse its epoch/iteration
    # number from the file name
    best_epoch_iter_full_path = list(
        sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1]
    best_epoch_or_iter_model_path = best_epoch_iter_full_path.split('/')[-1]
    best_epoch_or_iter = best_epoch_or_iter_model_path.\
        split('_')[-1].split('.')[0]
    return best_epoch_or_iter_model_path, int(best_epoch_or_iter)


def get_real_epoch_or_iter(config):
    # account for RepeatDataset so the reported epochs match real dataset passes
    cfg = Config.fromfile('./configs/' + config)
    if cfg.runner.type == 'EpochBasedRunner':
        epoch = cfg.runner.max_epochs
        if cfg.data.train.type == 'RepeatDataset':
            epoch *= cfg.data.train.times
        return epoch
    else:
        return cfg.runner.max_iters


def get_final_results(log_json_path,
                      epoch_or_iter,
                      results_lut,
                      by_epoch=True):
    # parse the json training log to collect memory usage and the final
    # validation metrics listed in results_lut
    result_dict = dict()
    last_val_line = None
    last_train_line = None
    last_val_line_idx = -1
    last_train_line_idx = -1
    with open(log_json_path, 'r') as f:
        for i, line in enumerate(f.readlines()):
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            if by_epoch:
                if (log_line['mode'] == 'train'
                        and log_line['epoch'] == epoch_or_iter):
                    result_dict['memory'] = log_line['memory']

                if (log_line['mode'] == 'val'
                        and log_line['epoch'] == epoch_or_iter):
                    result_dict.update({
                        key: log_line[key]
                        for key in results_lut if key in log_line
                    })
                    return result_dict
            else:
                if log_line['mode'] == 'train':
                    last_train_line_idx = i
                    last_train_line = log_line

                if log_line and log_line['mode'] == 'val':
                    last_val_line_idx = i
                    last_val_line = log_line

        # bug: max_iters = 768, last_train_line['iter'] = 750
        assert last_val_line_idx == last_train_line_idx + 1, \
            'Log file is incomplete'
        result_dict['memory'] = last_train_line['memory']
        result_dict.update({
            key: last_val_line[key]
            for key in results_lut if key in last_val_line
        })

        return result_dict


def get_dataset_name(config):
    # if there are more datasets, add them here.
    name_map = dict(
        CityscapesDataset='Cityscapes',
        CocoDataset='COCO',
        CocoPanopticDataset='COCO',
        DeepFashionDataset='Deep Fashion',
        LVISV05Dataset='LVIS v0.5',
        LVISV1Dataset='LVIS v1',
        VOCDataset='Pascal VOC',
        WIDERFaceDataset='WIDER Face',
        OpenImagesDataset='OpenImagesDataset',
        OpenImagesChallengeDataset='OpenImagesChallengeDataset',
        Objects365V1Dataset='Objects365 v1',
        Objects365V2Dataset='Objects365 v2')
    cfg = Config.fromfile('./configs/' + config)
    return name_map[cfg.dataset_type]


def convert_model_info_to_pwc(model_infos):
    # convert gathered model info into per-folder metafile (pwc) entries
    pwc_files = {}
    for model in model_infos:
        cfg_folder_name = osp.split(model['config'])[-2]
        pwc_model_info = OrderedDict()
        pwc_model_info['Name'] = osp.split(model['config'])[-1].split('.')[0]
        pwc_model_info['In Collection'] = 'Please fill in Collection name'
        pwc_model_info['Config'] = osp.join('configs', model['config'])

        # get metadata
        memory = round(model['results']['memory'] / 1024, 1)
        meta_data = OrderedDict()
        meta_data['Training Memory (GB)'] = memory
        if 'epochs' in model:
            meta_data['Epochs'] = get_real_epoch_or_iter(model['config'])
        else:
            meta_data['Iterations'] = get_real_epoch_or_iter(model['config'])
        pwc_model_info['Metadata'] = meta_data

        # get dataset name
        dataset_name = get_dataset_name(model['config'])

        # get results
        results = []
        # if there are more metrics, add them here.
        if 'bbox_mAP' in model['results']:
            metric = round(model['results']['bbox_mAP'] * 100, 1)
            results.append(
                OrderedDict(
                    Task='Object Detection',
                    Dataset=dataset_name,
                    Metrics={'box AP': metric}))
        if 'segm_mAP' in model['results']:
            metric = round(model['results']['segm_mAP'] * 100, 1)
            results.append(
                OrderedDict(
                    Task='Instance Segmentation',
                    Dataset=dataset_name,
                    Metrics={'mask AP': metric}))
        if 'PQ' in model['results']:
            metric = round(model['results']['PQ'], 1)
            results.append(
                OrderedDict(
                    Task='Panoptic Segmentation',
                    Dataset=dataset_name,
                    Metrics={'PQ': metric}))
        pwc_model_info['Results'] = results

        link_string = 'https://download.openmmlab.com/mmdetection/v2.0/'
        link_string += '{}/{}'.format(osp.splitext(model['config'])[0],
                                      osp.split(model['model_path'])[-1])
        pwc_model_info['Weights'] = link_string
        if cfg_folder_name in pwc_files:
            pwc_files[cfg_folder_name].append(pwc_model_info)
        else:
            pwc_files[cfg_folder_name] = [pwc_model_info]
    return pwc_files


def parse_args():
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        'root',
        type=str,
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        'out', type=str, help='output path of gathered models to be stored')
    parser.add_argument(
        '--best',
        action='store_true',
        help='whether to gather the best model.')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    models_root = args.root
    models_out = args.out
    mkdir_or_exist(models_out)

    # find all models in the root directory to be gathered
    raw_configs = list(scandir('./configs', '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        if osp.exists(osp.join(models_root, raw_config)):
            used_configs.append(raw_config)
    print(f'Found {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file for each trained config
    # and parse the best performance
    model_infos = []
    for used_config in used_configs:
        exp_dir = osp.join(models_root, used_config)
        by_epoch = is_by_epoch(used_config)
        # check whether the experiment has finished
        if args.best:
            final_model, final_epoch_or_iter = get_best_epoch_or_iter(exp_dir)
        else:
            final_epoch_or_iter = get_final_epoch_or_iter(used_config)
            final_model = '{}_{}.pth'.format('epoch' if by_epoch else 'iter',
                                             final_epoch_or_iter)

        model_path = osp.join(exp_dir, final_model)
        # skip if the model is still training
        if not osp.exists(model_path):
            continue

        # get the latest logs
        log_json_path = list(
            sorted(glob.glob(osp.join(exp_dir, '*.log.json'))))[-1]
        log_txt_path = list(sorted(glob.glob(osp.join(exp_dir, '*.log'))))[-1]
        cfg = Config.fromfile('./configs/' + used_config)
        results_lut = cfg.evaluation.metric
        if not isinstance(results_lut, list):
            results_lut = [results_lut]
        # when using VOC, the evaluation key is only 'mAP';
        # when using a panoptic dataset, the evaluation key is 'PQ'.
        for i, key in enumerate(results_lut):
            if 'mAP' not in key and 'PQ' not in key:
                results_lut[i] = key + '_mAP'
        model_performance = get_final_results(log_json_path,
                                              final_epoch_or_iter, results_lut,
                                              by_epoch)

        if model_performance is None:
            continue

        model_time = osp.split(log_txt_path)[-1].split('.')[0]
        model_info = dict(
            config=used_config,
            results=model_performance,
            model_time=model_time,
            final_model=final_model,
            log_json_path=osp.split(log_json_path)[-1])
        model_info['epochs' if by_epoch else 'iterations'] = \
            final_epoch_or_iter
        model_infos.append(model_info)

    # publish model for each checkpoint
    publish_model_infos = []
    for model in model_infos:
        model_publish_dir = osp.join(models_out,
                                     osp.splitext(model['config'])[0])
        mkdir_or_exist(model_publish_dir)

        model_name = osp.split(model['config'])[-1].split('.')[0]
        model_name += '_' + model['model_time']
        publish_model_path = osp.join(model_publish_dir, model_name)
        trained_model_path = osp.join(models_root, model['config'],
                                      model['final_model'])

        # convert model
        final_model_path = process_checkpoint(trained_model_path,
                                              publish_model_path)

        # copy log
        shutil.copy(
            osp.join(models_root, model['config'], model['log_json_path']),
            osp.join(model_publish_dir, f'{model_name}.log.json'))
        shutil.copy(
            osp.join(models_root, model['config'],
                     osp.splitext(model['log_json_path'])[0]),
            osp.join(model_publish_dir, f'{model_name}.log'))

        # copy config to guarantee reproducibility
        config_path = model['config']
        config_path = osp.join(
            'configs',
            config_path) if 'configs' not in config_path else config_path
        target_config_path = osp.split(config_path)[-1]
        shutil.copy(config_path, osp.join(model_publish_dir,
                                          target_config_path))

        model['model_path'] = final_model_path
        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    print(f'Gathered {len(publish_model_infos)} models in total')
    dump(models, osp.join(models_out, 'model_info.json'))

    pwc_files = convert_model_info_to_pwc(publish_model_infos)
    for name in pwc_files:
        # open in binary mode: yaml.dump emits bytes when `encoding` is set
        with open(osp.join(models_out, name + '_metafile.yml'), 'wb') as f:
            ordered_yaml_dump(pwc_files[name], f, encoding='utf-8')


if __name__ == '__main__':
    main()
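# Usage sketch (not part of the original script; the directory names below are
# illustrative). The positional arguments match parse_args() above:
#
#   python gather_models.py work_dirs/ gathered_models/ --best
#
# The script expects each trained config to have a matching sub-directory under
# the root containing the checkpoint(s) plus the *.log and *.log.json training
# logs. It writes the hashed checkpoints, copied logs and configs,
# model_info.json, and per-folder *_metafile.yml files into the output path.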