download_checkpoints.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import math
  4. import os
  5. import os.path as osp
  6. from multiprocessing import Pool
  7. import torch
  8. from mmengine.config import Config
  9. from mmengine.utils import mkdir_or_exist
  10. def download(url, out_file, min_bytes=math.pow(1024, 2), progress=True):
  11. # math.pow(1024, 2) is mean 1 MB
  12. assert_msg = f"Downloaded url '{url}' does not exist " \
  13. f'or size is < min_bytes={min_bytes}'
  14. try:
  15. print(f'Downloading {url} to {out_file}...')
  16. torch.hub.download_url_to_file(url, str(out_file), progress=progress)
  17. assert osp.exists(
  18. out_file) and osp.getsize(out_file) > min_bytes, assert_msg
  19. except Exception as e:
  20. if osp.exists(out_file):
  21. os.remove(out_file)
  22. print(f'ERROR: {e}\nRe-attempting {url} to {out_file} ...')
  23. os.system(f"curl -L '{url}' -o '{out_file}' --retry 3 -C -"
  24. ) # curl download, retry and resume on fail
  25. finally:
  26. if osp.exists(out_file) and osp.getsize(out_file) < min_bytes:
  27. os.remove(out_file) # remove partial downloads
  28. if not osp.exists(out_file):
  29. print(f'ERROR: {assert_msg}\n')
  30. print('=========================================\n')
  31. def parse_args():
  32. parser = argparse.ArgumentParser(description='Download checkpoints')
  33. parser.add_argument('config', help='test config file path')
  34. parser.add_argument(
  35. 'out', type=str, help='output dir of checkpoints to be stored')
  36. parser.add_argument(
  37. '--nproc', type=int, default=16, help='num of Processes')
  38. parser.add_argument(
  39. '--intranet',
  40. action='store_true',
  41. help='switch to internal network url')
  42. args = parser.parse_args()
  43. return args
  44. if __name__ == '__main__':
  45. args = parse_args()
  46. mkdir_or_exist(args.out)
  47. cfg = Config.fromfile(args.config)
  48. checkpoint_url_list = []
  49. checkpoint_out_list = []
  50. for model in cfg:
  51. model_infos = cfg[model]
  52. if not isinstance(model_infos, list):
  53. model_infos = [model_infos]
  54. for model_info in model_infos:
  55. checkpoint = model_info['checkpoint']
  56. out_file = osp.join(args.out, checkpoint)
  57. if not osp.exists(out_file):
  58. url = model_info['url']
  59. if args.intranet is True:
  60. url = url.replace('.com', '.sensetime.com')
  61. url = url.replace('https', 'http')
  62. checkpoint_url_list.append(url)
  63. checkpoint_out_list.append(out_file)
  64. if len(checkpoint_url_list) > 0:
  65. pool = Pool(min(os.cpu_count(), args.nproc))
  66. pool.starmap(download, zip(checkpoint_url_list, checkpoint_out_list))
  67. else:
  68. print('No files to download!')