selfsup2mmdet.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from collections import OrderedDict
  4. import torch
  5. def moco_convert(src, dst):
  6. """Convert keys in pycls pretrained moco models to mmdet style."""
  7. # load caffe model
  8. moco_model = torch.load(src)
  9. blobs = moco_model['state_dict']
  10. # convert to pytorch style
  11. state_dict = OrderedDict()
  12. for k, v in blobs.items():
  13. if not k.startswith('module.encoder_q.'):
  14. continue
  15. old_k = k
  16. k = k.replace('module.encoder_q.', '')
  17. state_dict[k] = v
  18. print(old_k, '->', k)
  19. # save checkpoint
  20. checkpoint = dict()
  21. checkpoint['state_dict'] = state_dict
  22. torch.save(checkpoint, dst)
  23. def main():
  24. parser = argparse.ArgumentParser(description='Convert model keys')
  25. parser.add_argument('src', help='src detectron model path')
  26. parser.add_argument('dst', help='save path')
  27. parser.add_argument(
  28. '--selfsup', type=str, choices=['moco', 'swav'], help='save path')
  29. args = parser.parse_args()
  30. if args.selfsup == 'moco':
  31. moco_convert(args.src, args.dst)
  32. elif args.selfsup == 'swav':
  33. print('SWAV does not need to convert the keys')
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()