12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- from collections import OrderedDict
- import numpy as np
- import torch
- from mmengine.fileio import load
- def convert(src, dst):
- if src.endswith('pth'):
- src_model = torch.load(src)
- else:
- src_model = load(src)
- dst_state_dict = OrderedDict()
- for k, v in src_model['model'].items():
- key_name_split = k.split('.')
- if 'backbone.fpn_lateral' in k:
- lateral_id = int(key_name_split[-2][-1])
- name = f'neck.lateral_convs.{lateral_id - 2}.' \
- f'conv.{key_name_split[-1]}'
- elif 'backbone.fpn_output' in k:
- lateral_id = int(key_name_split[-2][-1])
- name = f'neck.fpn_convs.{lateral_id - 2}.conv.' \
- f'{key_name_split[-1]}'
- elif 'backbone.bottom_up.stem.conv1.norm.' in k:
- name = f'backbone.bn1.{key_name_split[-1]}'
- elif 'backbone.bottom_up.stem.conv1.' in k:
- name = f'backbone.conv1.{key_name_split[-1]}'
- elif 'backbone.bottom_up.res' in k:
- # weight_type = key_name_split[-1]
- res_id = int(key_name_split[2][-1]) - 1
- # deal with short cut
- if 'shortcut' in key_name_split[4]:
- if 'shortcut' == key_name_split[-2]:
- name = f'backbone.layer{res_id}.' \
- f'{key_name_split[3]}.downsample.0.' \
- f'{key_name_split[-1]}'
- elif 'shortcut' == key_name_split[-3]:
- name = f'backbone.layer{res_id}.' \
- f'{key_name_split[3]}.downsample.1.' \
- f'{key_name_split[-1]}'
- else:
- print(f'Unvalid key {k}')
- # deal with conv
- elif 'conv' in key_name_split[-2]:
- conv_id = int(key_name_split[-2][-1])
- name = f'backbone.layer{res_id}.{key_name_split[3]}' \
- f'.conv{conv_id}.{key_name_split[-1]}'
- # deal with BN
- elif key_name_split[-2] == 'norm':
- conv_id = int(key_name_split[-3][-1])
- name = f'backbone.layer{res_id}.{key_name_split[3]}.' \
- f'bn{conv_id}.{key_name_split[-1]}'
- else:
- print(f'{k} is invalid')
- elif key_name_split[0] == 'head':
- # d2: head.xxx -> mmdet: bbox_head.xxx
- name = f'bbox_{k}'
- else:
- # some base parameters such as beta will not convert
- print(f'{k} is not converted!!')
- continue
- if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
- raise ValueError(
- 'Unsupported type found in checkpoint! {}: {}'.format(
- k, type(v)))
- if not isinstance(v, torch.Tensor):
- dst_state_dict[name] = torch.from_numpy(v)
- else:
- dst_state_dict[name] = v
- mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
- torch.save(mmdet_model, dst)
def main():
    """Command-line entry point: read ``src`` and ``dst`` and run ``convert``."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    # Both arguments are positional; register them from a small table.
    positional = (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()
    convert(opts.src, opts.dst)
if __name__ == "__main__":
    # Run the converter only when executed as a script, never on import.
    main()
|