# detectron2_to_mmdet.py (1.6 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from collections import OrderedDict
  4. import torch
  5. from mmengine.fileio import load
  6. from mmengine.runner import save_checkpoint
  7. def convert(src: str, dst: str, prefix: str = 'd2_model') -> None:
  8. """Convert Detectron2 checkpoint to MMDetection style.
  9. Args:
  10. src (str): The Detectron2 checkpoint path, should endswith `pkl`.
  11. dst (str): The MMDetection checkpoint path.
  12. prefix (str): The prefix of MMDetection model, defaults to 'd2_model'.
  13. """
  14. # load arch_settings
  15. assert src.endswith('pkl'), \
  16. 'the source Detectron2 checkpoint should endswith `pkl`.'
  17. d2_model = load(src, encoding='latin1').get('model')
  18. assert d2_model is not None
  19. # convert to mmdet style
  20. dst_state_dict = OrderedDict()
  21. for name, value in d2_model.items():
  22. if not isinstance(value, torch.Tensor):
  23. value = torch.from_numpy(value)
  24. dst_state_dict[f'{prefix}.{name}'] = value
  25. mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
  26. save_checkpoint(mmdet_model, dst)
  27. print(f'Convert Detectron2 model {src} to MMDetection model {dst}')
  28. def main():
  29. parser = argparse.ArgumentParser(
  30. description='Convert Detectron2 checkpoint to MMDetection style')
  31. parser.add_argument('src', help='Detectron2 model path')
  32. parser.add_argument('dst', help='MMDetectron model save path')
  33. parser.add_argument(
  34. '--prefix', default='d2_model', type=str, help='prefix of the model')
  35. args = parser.parse_args()
  36. convert(args.src, args.dst, args.prefix)
# Run the command-line entry point only when this file is executed as a
# script; importing the module does not trigger a conversion.
if __name__ == '__main__':
    main()