12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import re
import tempfile
from collections import OrderedDict

import torch
from mmengine import Config
def parse_config(config_strings):
    """Load a config from its Python source text and check it is an SSD model.

    ``Config.fromfile`` needs a real ``.py`` path, so the text is written to
    a file inside a temporary directory that is removed when loading and
    checking are done.

    Args:
        config_strings (str): Python source of the config to load.

    Raises:
        AssertionError: If ``config.model.bbox_head.type`` is not
            ``'SSDHead'``.
    """
    # A TemporaryDirectory guarantees cleanup. The previous approach wrote to
    # a sibling path ``<NamedTemporaryFile>.py`` that nothing ever deleted.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = osp.join(tmp_dir, 'config.py')
        # Explicit UTF-8 so non-ASCII config text round-trips regardless of
        # the platform's default locale encoding.
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)
        # check whether it is SSD
        if config.model.bbox_head.type != 'SSDHead':
            raise AssertionError('This is not a SSD model.')
def _torch_version_at_least(major, minor):
    """Return True if the installed torch version is at least major.minor."""
    # Numeric comparison: plain string comparison would rank '1.10' below
    # '1.6'. An unparsable version falls back to False, which selects the
    # plain ``torch.save`` call that works on every version.
    found = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
    if found is None:
        return False
    return (int(found.group(1)), int(found.group(2))) >= (major, minor)


def _convert_key(key):
    """Map one old-style SSD state-dict key to its upgraded name."""
    if 'extra' in key:
        # Two consecutive old extra layers map to one new extra_layers entry
        # (layer_idx // 2), at sub-position layer_idx % 2.
        layer_idx = int(key.split('.')[2])
        return 'neck.extra_layers.{}.{}.conv.'.format(
            layer_idx // 2, layer_idx % 2) + key.split('.')[-1]
    if 'l2_norm' in key:
        return 'neck.l2_norm.weight'
    if 'bbox_head' in key:
        # Insert an extra '0' component after the per-level index, e.g.
        # 'bbox_head.cls_convs.0.weight' -> 'bbox_head.cls_convs.0.0.weight'.
        # Splitting on '.' instead of slicing at a fixed 21-character offset
        # also handles level indices with more than one digit.
        parts = key.split('.')
        parts.insert(3, '0')
        return '.'.join(parts)
    return key


def convert(in_file, out_file):
    """Rename the keys of an old SSD checkpoint and save the result.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path the converted checkpoint is written to.

    Raises:
        KeyError: If the checkpoint lacks ``'state_dict'`` or ``'meta'``,
            or ``meta`` lacks ``'config'``.
        AssertionError: If the embedded config is not an SSD model.
    """
    # map_location='cpu' allows converting checkpoints whose tensors were
    # saved on a GPU from a machine without one.
    checkpoint = torch.load(in_file, map_location='cpu')
    in_state_dict = checkpoint.pop('state_dict')
    meta_info = checkpoint['meta']
    # NOTE(review): the leading '#' comments out the first line of the stored
    # config text, presumably a header line that is not valid Python. Confirm
    # against how the checkpoint's meta was written.
    parse_config('#' + meta_info['config'])
    checkpoint['state_dict'] = OrderedDict(
        (_convert_key(key), value) for key, value in in_state_dict.items())
    if _torch_version_at_least(1, 6):
        # Write the legacy (non-zip) serialization format, presumably so that
        # torch < 1.6 can still load the converted file.
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
def main():
    """Command-line entry point: upgrade a single SSD checkpoint file."""
    cli = argparse.ArgumentParser(description='Upgrade SSD version')
    positional_args = (
        ('in_file', 'input checkpoint file'),
        ('out_file', 'output checkpoint file'),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, help=arg_help)
    parsed = cli.parse_args()
    convert(parsed.in_file, parsed.out_file)


if __name__ == '__main__':
    main()
|