1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- from collections import OrderedDict
- import torch
- from mmengine.fileio import load
# Supported ResNet depths -> number of residual blocks in each of the four
# stages (Detectron res2..res5, PyTorch layer1..layer4).
arch_settings = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
}
def convert_bn(blobs, state_dict, caffe_name, torch_name, converted_names):
    """Copy a Detectron affine-channel layer into BatchNorm parameters.

    Detectron replaces BN with an affine channel layer that only stores a
    shift (``<caffe_name>_b``) and a scale (``<caffe_name>_s``). Those become
    the BN ``bias`` and ``weight``; the running statistics are filled with
    zeros (mean) and ones (variance).

    Args:
        blobs (dict): Detectron blob name -> numpy array.
        state_dict (dict): Destination state dict, updated in place.
        caffe_name (str): Blob-name prefix of the affine channel layer.
        torch_name (str): Parameter-name prefix of the BatchNorm layer.
        converted_names (set): Updated in place with the consumed blob names.
    """
    shift_key = caffe_name + '_b'
    scale_key = caffe_name + '_s'

    state_dict[torch_name + '.bias'] = torch.from_numpy(blobs[shift_key])
    scale = torch.from_numpy(blobs[scale_key])
    state_dict[torch_name + '.weight'] = scale

    # With mean 0 and var 1, eval-mode BN reduces (up to its eps term) to
    # weight * x + bias, i.e. the original affine channel op.
    state_dict[torch_name + '.running_mean'] = torch.zeros(scale.size())
    state_dict[torch_name + '.running_var'] = torch.ones(scale.size())

    converted_names.update((shift_key, scale_key))
def convert_conv_fc(blobs, state_dict, caffe_name, torch_name,
                    converted_names):
    """Copy a Detectron conv / fc layer into PyTorch parameters.

    ``<caffe_name>_w`` is required and becomes ``<torch_name>.weight``.
    ``<caffe_name>_b`` is optional and, when present, becomes
    ``<torch_name>.bias``.

    Args:
        blobs (dict): Detectron blob name -> numpy array.
        state_dict (dict): Destination state dict, updated in place.
        caffe_name (str): Blob-name prefix of the layer.
        torch_name (str): Parameter-name prefix of the layer.
        converted_names (set): Updated in place with the consumed blob names.
    """
    weight_key = caffe_name + '_w'
    state_dict[torch_name + '.weight'] = torch.from_numpy(blobs[weight_key])
    converted_names.add(weight_key)

    bias_key = caffe_name + '_b'
    if bias_key not in blobs:
        return  # bias is optional; nothing more to copy
    state_dict[torch_name + '.bias'] = torch.from_numpy(blobs[bias_key])
    converted_names.add(bias_key)
def convert(src, dst, depth):
    """Convert keys in detectron pretrained ResNet models to pytorch style.

    Loads a Detectron checkpoint and copies the stem and every stage's conv
    and affine-channel blobs into a pytorch-style ``state_dict`` (``conv1``,
    ``bn1``, ``layer{i}.{j}.conv{k}``, ``layer{i}.{j}.bn{k}`` and
    ``layer{i}.{j}.downsample.{0,1}``). It then saves
    ``{'state_dict': state_dict}`` to ``dst``. Every blob that was not
    consumed is printed as ``Not Convert: <name>``.

    Args:
        src (str): Path of the Detectron model, read with
            ``mmengine.fileio.load``.
        dst (str): Output path passed to ``torch.save``.
        depth (int): ResNet depth; must be a key of ``arch_settings``.

    Raises:
        ValueError: If ``depth`` is not a key of ``arch_settings``.
    """
    if depth not in arch_settings:
        # Build the message from the table so it never drifts from it.
        supported = ' and '.join(f'ResNet-{d}' for d in sorted(arch_settings))
        raise ValueError(f'Only support {supported} currently')
    block_nums = arch_settings[depth]

    # NOTE(review): Detectron checkpoints are presumably Python 2 pickles
    # (hence the latin1 encoding). Unpickling can execute arbitrary code, so
    # only convert files from a trusted source.
    caffe_model = load(src, encoding='latin1')
    # Trained models nest the weights under 'blobs'; otherwise the loaded
    # object itself is the blob dict.
    blobs = caffe_model['blobs'] if 'blobs' in caffe_model else caffe_model

    state_dict = OrderedDict()
    converted_names = set()

    # Stem.
    convert_conv_fc(blobs, state_dict, 'conv1', 'conv1', converted_names)
    convert_bn(blobs, state_dict, 'res_conv1_bn', 'bn1', converted_names)

    # Detectron stages res2..res5 map to pytorch layer1..layer4.
    for layer_idx, num_blocks in enumerate(block_nums, start=1):
        for block_idx in range(num_blocks):
            caffe_prefix = f'res{layer_idx + 1}_{block_idx}'
            torch_prefix = f'layer{layer_idx}.{block_idx}'
            if block_idx == 0:
                # Only block 0 of each stage has a shortcut branch (branch1).
                convert_conv_fc(blobs, state_dict, f'{caffe_prefix}_branch1',
                                f'{torch_prefix}.downsample.0',
                                converted_names)
                convert_bn(blobs, state_dict, f'{caffe_prefix}_branch1_bn',
                           f'{torch_prefix}.downsample.1', converted_names)
            # Residual branch convs 2a/2b/2c -> conv1/conv2/conv3.
            for k, letter in enumerate('abc', start=1):
                convert_conv_fc(blobs, state_dict,
                                f'{caffe_prefix}_branch2{letter}',
                                f'{torch_prefix}.conv{k}', converted_names)
                convert_bn(blobs, state_dict,
                           f'{caffe_prefix}_branch2{letter}_bn',
                           f'{torch_prefix}.bn{k}', converted_names)

    # Report any blob that no mapping above consumed.
    for key in blobs:
        if key not in converted_names:
            print(f'Not Convert: {key}')

    torch.save(dict(state_dict=state_dict), dst)
def main():
    """Command-line entry point: ``<script> src dst depth``.

    Parses the three positional arguments and forwards them to ``convert``.
    """
    cli = argparse.ArgumentParser(description='Convert model keys')
    positional_args = (
        ('src', {'help': 'src detectron model path'}),
        ('dst', {'help': 'save path'}),
        ('depth', {'type': int, 'help': 'ResNet model depth'}),
    )
    for arg_name, arg_options in positional_args:
        cli.add_argument(arg_name, **arg_options)
    opts = cli.parse_args()
    convert(src=opts.src, dst=opts.dst, depth=opts.depth)
# Run the converter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|