detectron2pytorch.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from collections import OrderedDict
  4. import torch
  5. from mmengine.fileio import load
  6. arch_settings = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3)}
  7. def convert_bn(blobs, state_dict, caffe_name, torch_name, converted_names):
  8. # detectron replace bn with affine channel layer
  9. state_dict[torch_name + '.bias'] = torch.from_numpy(blobs[caffe_name +
  10. '_b'])
  11. state_dict[torch_name + '.weight'] = torch.from_numpy(blobs[caffe_name +
  12. '_s'])
  13. bn_size = state_dict[torch_name + '.weight'].size()
  14. state_dict[torch_name + '.running_mean'] = torch.zeros(bn_size)
  15. state_dict[torch_name + '.running_var'] = torch.ones(bn_size)
  16. converted_names.add(caffe_name + '_b')
  17. converted_names.add(caffe_name + '_s')
  18. def convert_conv_fc(blobs, state_dict, caffe_name, torch_name,
  19. converted_names):
  20. state_dict[torch_name + '.weight'] = torch.from_numpy(blobs[caffe_name +
  21. '_w'])
  22. converted_names.add(caffe_name + '_w')
  23. if caffe_name + '_b' in blobs:
  24. state_dict[torch_name + '.bias'] = torch.from_numpy(blobs[caffe_name +
  25. '_b'])
  26. converted_names.add(caffe_name + '_b')
  27. def convert(src, dst, depth):
  28. """Convert keys in detectron pretrained ResNet models to pytorch style."""
  29. # load arch_settings
  30. if depth not in arch_settings:
  31. raise ValueError('Only support ResNet-50 and ResNet-101 currently')
  32. block_nums = arch_settings[depth]
  33. # load caffe model
  34. caffe_model = load(src, encoding='latin1')
  35. blobs = caffe_model['blobs'] if 'blobs' in caffe_model else caffe_model
  36. # convert to pytorch style
  37. state_dict = OrderedDict()
  38. converted_names = set()
  39. convert_conv_fc(blobs, state_dict, 'conv1', 'conv1', converted_names)
  40. convert_bn(blobs, state_dict, 'res_conv1_bn', 'bn1', converted_names)
  41. for i in range(1, len(block_nums) + 1):
  42. for j in range(block_nums[i - 1]):
  43. if j == 0:
  44. convert_conv_fc(blobs, state_dict, f'res{i + 1}_{j}_branch1',
  45. f'layer{i}.{j}.downsample.0', converted_names)
  46. convert_bn(blobs, state_dict, f'res{i + 1}_{j}_branch1_bn',
  47. f'layer{i}.{j}.downsample.1', converted_names)
  48. for k, letter in enumerate(['a', 'b', 'c']):
  49. convert_conv_fc(blobs, state_dict,
  50. f'res{i + 1}_{j}_branch2{letter}',
  51. f'layer{i}.{j}.conv{k+1}', converted_names)
  52. convert_bn(blobs, state_dict,
  53. f'res{i + 1}_{j}_branch2{letter}_bn',
  54. f'layer{i}.{j}.bn{k + 1}', converted_names)
  55. # check if all layers are converted
  56. for key in blobs:
  57. if key not in converted_names:
  58. print(f'Not Convert: {key}')
  59. # save checkpoint
  60. checkpoint = dict()
  61. checkpoint['state_dict'] = state_dict
  62. torch.save(checkpoint, dst)
  63. def main():
  64. parser = argparse.ArgumentParser(description='Convert model keys')
  65. parser.add_argument('src', help='src detectron model path')
  66. parser.add_argument('dst', help='save path')
  67. parser.add_argument('depth', type=int, help='ResNet model depth')
  68. args = parser.parse_args()
  69. convert(args.src, args.dst, args.depth)
  70. if __name__ == '__main__':
  71. main()