regnet2mmdet.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from collections import OrderedDict
  4. import torch
  5. def convert_stem(model_key, model_weight, state_dict, converted_names):
  6. new_key = model_key.replace('stem.conv', 'conv1')
  7. new_key = new_key.replace('stem.bn', 'bn1')
  8. state_dict[new_key] = model_weight
  9. converted_names.add(model_key)
  10. print(f'Convert {model_key} to {new_key}')
  11. def convert_head(model_key, model_weight, state_dict, converted_names):
  12. new_key = model_key.replace('head.fc', 'fc')
  13. state_dict[new_key] = model_weight
  14. converted_names.add(model_key)
  15. print(f'Convert {model_key} to {new_key}')
  16. def convert_reslayer(model_key, model_weight, state_dict, converted_names):
  17. split_keys = model_key.split('.')
  18. layer, block, module = split_keys[:3]
  19. block_id = int(block[1:])
  20. layer_name = f'layer{int(layer[1:])}'
  21. block_name = f'{block_id - 1}'
  22. if block_id == 1 and module == 'bn':
  23. new_key = f'{layer_name}.{block_name}.downsample.1.{split_keys[-1]}'
  24. elif block_id == 1 and module == 'proj':
  25. new_key = f'{layer_name}.{block_name}.downsample.0.{split_keys[-1]}'
  26. elif module == 'f':
  27. if split_keys[3] == 'a_bn':
  28. module_name = 'bn1'
  29. elif split_keys[3] == 'b_bn':
  30. module_name = 'bn2'
  31. elif split_keys[3] == 'c_bn':
  32. module_name = 'bn3'
  33. elif split_keys[3] == 'a':
  34. module_name = 'conv1'
  35. elif split_keys[3] == 'b':
  36. module_name = 'conv2'
  37. elif split_keys[3] == 'c':
  38. module_name = 'conv3'
  39. new_key = f'{layer_name}.{block_name}.{module_name}.{split_keys[-1]}'
  40. else:
  41. raise ValueError(f'Unsupported conversion of key {model_key}')
  42. print(f'Convert {model_key} to {new_key}')
  43. state_dict[new_key] = model_weight
  44. converted_names.add(model_key)
  45. def convert(src, dst):
  46. """Convert keys in pycls pretrained RegNet models to mmdet style."""
  47. # load caffe model
  48. regnet_model = torch.load(src)
  49. blobs = regnet_model['model_state']
  50. # convert to pytorch style
  51. state_dict = OrderedDict()
  52. converted_names = set()
  53. for key, weight in blobs.items():
  54. if 'stem' in key:
  55. convert_stem(key, weight, state_dict, converted_names)
  56. elif 'head' in key:
  57. convert_head(key, weight, state_dict, converted_names)
  58. elif key.startswith('s'):
  59. convert_reslayer(key, weight, state_dict, converted_names)
  60. # check if all layers are converted
  61. for key in blobs:
  62. if key not in converted_names:
  63. print(f'not converted: {key}')
  64. # save checkpoint
  65. checkpoint = dict()
  66. checkpoint['state_dict'] = state_dict
  67. torch.save(checkpoint, dst)
  68. def main():
  69. parser = argparse.ArgumentParser(description='Convert model keys')
  70. parser.add_argument('src', help='src detectron model path')
  71. parser.add_argument('dst', help='save path')
  72. args = parser.parse_args()
  73. convert(args.src, args.dst)
# Run the converter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()