upgrade_ssd_version.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import tempfile
  4. from collections import OrderedDict
  5. import torch
  6. from mmengine import Config
  7. def parse_config(config_strings):
  8. temp_file = tempfile.NamedTemporaryFile()
  9. config_path = f'{temp_file.name}.py'
  10. with open(config_path, 'w') as f:
  11. f.write(config_strings)
  12. config = Config.fromfile(config_path)
  13. # check whether it is SSD
  14. if config.model.bbox_head.type != 'SSDHead':
  15. raise AssertionError('This is not a SSD model.')
  16. def convert(in_file, out_file):
  17. checkpoint = torch.load(in_file)
  18. in_state_dict = checkpoint.pop('state_dict')
  19. out_state_dict = OrderedDict()
  20. meta_info = checkpoint['meta']
  21. parse_config('#' + meta_info['config'])
  22. for key, value in in_state_dict.items():
  23. if 'extra' in key:
  24. layer_idx = int(key.split('.')[2])
  25. new_key = 'neck.extra_layers.{}.{}.conv.'.format(
  26. layer_idx // 2, layer_idx % 2) + key.split('.')[-1]
  27. elif 'l2_norm' in key:
  28. new_key = 'neck.l2_norm.weight'
  29. elif 'bbox_head' in key:
  30. new_key = key[:21] + '.0' + key[21:]
  31. else:
  32. new_key = key
  33. out_state_dict[new_key] = value
  34. checkpoint['state_dict'] = out_state_dict
  35. if torch.__version__ >= '1.6':
  36. torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
  37. else:
  38. torch.save(checkpoint, out_file)
  39. def main():
  40. parser = argparse.ArgumentParser(description='Upgrade SSD version')
  41. parser.add_argument('in_file', help='input checkpoint file')
  42. parser.add_argument('out_file', help='output checkpoint file')
  43. args = parser.parse_args()
  44. convert(args.in_file, args.out_file)
  45. if __name__ == '__main__':
  46. main()