# Copyright (c) OpenMMLab. All rights reserved.
"""Convert Detectron2 checkpoint keys to MMDetection naming.

The source checkpoint's weights are read from its top-level ``'model'``
entry; the converted weights are saved as ``{'state_dict': ..., 'meta': {}}``.
"""
import argparse
from collections import OrderedDict

import numpy as np
import torch


def _load_checkpoint(src):
    """Load the source checkpoint at ``src``.

    Paths ending in ``pth`` are read with ``torch.load``; anything else is
    read with ``mmengine.fileio.load``.
    """
    if src.endswith('pth'):
        # Map to CPU so checkpoints saved from GPU load on CPU-only hosts.
        return torch.load(src, map_location='cpu')
    # Imported lazily: only non-.pth inputs need mmengine.
    from mmengine.fileio import load
    return load(src)


def _convert_backbone_res_key(k, key_name_split):
    """Map a ``backbone.bottom_up.res*`` key to its MMDetection ResNet key.

    Returns None (after printing a message) if the key matches no known
    pattern.
    """
    # 'resN' -> 'layer{N-1}', e.g. res2 -> layer1.
    res_id = int(key_name_split[2][-1]) - 1
    prefix = f'backbone.layer{res_id}.{key_name_split[3]}'
    param = key_name_split[-1]
    # deal with short cut: conv -> downsample.0, its norm -> downsample.1
    if 'shortcut' in key_name_split[4]:
        if key_name_split[-2] == 'shortcut':
            return f'{prefix}.downsample.0.{param}'
        if key_name_split[-3] == 'shortcut':
            return f'{prefix}.downsample.1.{param}'
        print(f'Unvalid key {k}')
        return None
    # deal with conv: 'convN' -> 'convN'
    if 'conv' in key_name_split[-2]:
        conv_id = int(key_name_split[-2][-1])
        return f'{prefix}.conv{conv_id}.{param}'
    # deal with BN: 'convN.norm' -> 'bnN'
    if key_name_split[-2] == 'norm':
        conv_id = int(key_name_split[-3][-1])
        return f'{prefix}.bn{conv_id}.{param}'
    print(f'{k} is invalid')
    return None


def _convert_key(k):
    """Return the MMDetection name for Detectron2 key ``k``, or None to skip.

    A message is printed for every key that is not converted.
    """
    key_name_split = k.split('.')
    if 'backbone.fpn_lateral' in k:
        # FPN level N maps to neck index N - 2 (e.g. fpn_lateral3 -> 1).
        lateral_id = int(key_name_split[-2][-1])
        return (f'neck.lateral_convs.{lateral_id - 2}.'
                f'conv.{key_name_split[-1]}')
    if 'backbone.fpn_output' in k:
        lateral_id = int(key_name_split[-2][-1])
        return (f'neck.fpn_convs.{lateral_id - 2}.conv.'
                f'{key_name_split[-1]}')
    # The stem norm check must come before the generic stem conv1 check,
    # which would otherwise also match norm keys.
    if 'backbone.bottom_up.stem.conv1.norm.' in k:
        return f'backbone.bn1.{key_name_split[-1]}'
    if 'backbone.bottom_up.stem.conv1.' in k:
        return f'backbone.conv1.{key_name_split[-1]}'
    if 'backbone.bottom_up.res' in k:
        return _convert_backbone_res_key(k, key_name_split)
    if key_name_split[0] == 'head':
        # d2: head.xxx -> mmdet: bbox_head.xxx
        return f'bbox_{k}'
    # some base parameters such as beta will not convert
    print(f'{k} is not converted!!')
    return None


def convert(src, dst):
    """Convert the Detectron2 checkpoint at ``src`` and save it to ``dst``.

    Args:
        src (str): Source checkpoint path. Weights are read from its
            top-level ``'model'`` entry.
        dst (str): Output path; receives
            ``{'state_dict': OrderedDict, 'meta': {}}`` via ``torch.save``.

    Raises:
        ValueError: If a converted value is neither a numpy array nor a
            torch tensor.
    """
    src_model = _load_checkpoint(src)
    dst_state_dict = OrderedDict()
    for k, v in src_model['model'].items():
        name = _convert_key(k)
        if name is None:
            # Skip unmapped keys rather than storing them under a stale
            # name left over from a previous iteration.
            continue
        if not isinstance(v, (np.ndarray, torch.Tensor)):
            raise ValueError(
                'Unsupported type found in checkpoint! {}: {}'.format(
                    k, type(v)))
        if isinstance(v, torch.Tensor):
            dst_state_dict[name] = v
        else:
            dst_state_dict[name] = torch.from_numpy(v)
    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    torch.save(mmdet_model, dst)


def main():
    """Command-line entry point: ``script.py SRC DST``."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='src detectron model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()