123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
import argparse

import numpy as np
import torch
from tensorflow.python.training import py_checkpoint_reader

# High print precision so converted tensors can be compared digit-for-digit
# against the TensorFlow originals when debugging the conversion.
torch.set_printoptions(precision=20)
def tf2pth(v):
    """Permute a TensorFlow weight array into PyTorch memory layout.

    4-D conv kernels go from TF's (H, W, in, out) to PyTorch's
    (out, in, H, W); 2-D matrices are transposed; anything else is
    returned unchanged.
    """
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    if v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v
def convert_key(model_name, bifpn_repeats, weights):
    """Translate TF EfficientDet checkpoint names into PyTorch state-dict keys.

    Args:
        model_name: EfficientNet backbone name (e.g. 'efficientnet-b0');
            backbone variables in the checkpoint are prefixed with it.
        bifpn_repeats: number of stacked BiFPN cells in the neck.
        weights: mapping of TF variable name -> tensor, already permuted to
            PyTorch layout (see ``tf2pth``).

    Returns:
        dict mapping PyTorch state-dict parameter names to tensors.
    """
    # BiFPN fusion-weight accumulators, one vector per cell.  TF stores each
    # fusion weight as a separate scalar variable ('WSM', 'WSM_1', 'WSM_2');
    # entries start at a -1e4 sentinel and a vector is written to the output
    # only once every entry has been overwritten (torch.min(...) > -1e4).
    p6_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p5_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p4_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p3_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p4_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p5_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p6_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p7_w2 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    # Flat TF block index ('blocks_N') -> backbone layer key 'stage.block'.
    idx2key = {
        0: '1.0',
        1: '2.0',
        2: '2.1',
        3: '3.0',
        4: '3.1',
        5: '4.0',
        6: '4.1',
        7: '4.2',
        8: '4.3',
        9: '4.4',
        10: '4.5',
        11: '5.0',
        12: '5.1',
        13: '5.2',
        14: '5.3',
        15: '5.4'
    }
    m = dict()
    for k, v in weights.items():
        # Skip EMA shadow variables and the global step counter.
        if 'Exponential' in k or 'global_step' in k:
            continue
        seg = k.split('/')
        if len(seg) == 1:
            continue
        # Depthwise kernels need their first two axes swapped on top of the
        # generic tf2pth permutation.
        if seg[2] == 'depthwise_conv2d':
            v = v.transpose(1, 0)
        if seg[0] == model_name:
            # ---- EfficientNet backbone ----
            if seg[1] == 'stem':
                prefix = 'backbone.layers.0'
                mapping = {
                    'conv2d/kernel': 'conv.weight',
                    'tpu_batch_normalization/beta': 'bn.bias',
                    'tpu_batch_normalization/gamma': 'bn.weight',
                    'tpu_batch_normalization/moving_mean': 'bn.running_mean',
                    'tpu_batch_normalization/moving_variance':
                    'bn.running_var',
                }
                suffix = mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            elif seg[1].startswith('blocks_'):
                idx = int(seg[1][7:])
                prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
                base_mapping = {
                    'depthwise_conv2d/depthwise_kernel':
                    'depthwise_conv.conv.weight',
                    'se/conv2d/kernel': 'se.conv1.conv.weight',
                    'se/conv2d/bias': 'se.conv1.conv.bias',
                    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
                    'se/conv2d_1/bias': 'se.conv2.conv.bias'
                }
                # Block 0 has no expansion conv, so its batch-norm variables
                # are numbered one lower than in every other block.
                if idx == 0:
                    mapping = {
                        'conv2d/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                else:
                    mapping = {
                        'depthwise_conv2d/depthwise_kernel':
                        'depthwise_conv.conv.weight',
                        'conv2d/kernel':
                        'expand_conv.conv.weight',
                        'conv2d_1/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'expand_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'expand_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'expand_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'expand_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_2/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_2/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_2/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_2/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                mapping.update(base_mapping)
                suffix = mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
        elif seg[0] == 'resample_p6':
            # ---- extra P5 -> P6 downsample feeding the first BiFPN cell ----
            prefix = 'neck.bifpn.0.p5_to_p6.0'
            mapping = {
                'conv2d/kernel': 'down_conv.weight',
                'conv2d/bias': 'down_conv.bias',
                'bn/beta': 'bn.bias',
                'bn/gamma': 'bn.weight',
                'bn/moving_mean': 'bn.running_mean',
                'bn/moving_variance': 'bn.running_var',
            }
            suffix = mapping['/'.join(seg[1:])]
            m[prefix + '.' + suffix] = v
        elif seg[0] == 'fpn_cells':
            # ---- BiFPN cells: key shape is fpn_cells/cell_<i>/fnode<j>/... ----
            fpn_idx = int(seg[1][5:])
            prefix = '.'.join(['neck', 'bifpn', str(fpn_idx)])
            fnode_id = int(seg[2][5])
            if fnode_id == 0:
                # fnode 0: top-down P6 fusion (conv6_up, 2 fusion weights).
                mapping = {
                    'op_after_combine5/conv/depthwise_kernel':
                    'conv6_up.depthwise_conv.weight',
                    'op_after_combine5/conv/pointwise_kernel':
                    'conv6_up.pointwise_conv.weight',
                    'op_after_combine5/conv/bias':
                    'conv6_up.pointwise_conv.bias',
                    'op_after_combine5/bn/beta':
                    'conv6_up.bn.bias',
                    'op_after_combine5/bn/gamma':
                    'conv6_up.bn.weight',
                    'op_after_combine5/bn/moving_mean':
                    'conv6_up.bn.running_mean',
                    'op_after_combine5/bn/moving_variance':
                    'conv6_up.bn.running_var',
                }
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p6_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p6_w1[fpn_idx][1] = v
                # Emit the fusion vector once both scalars have arrived.
                if torch.min(p6_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p6_w1'] = p6_w1[fpn_idx]
            elif fnode_id == 1:
                # fnode 1: top-down P5 fusion (conv5_up); the first cell also
                # owns the P5 channel-reduction resample block.
                base_mapping = {
                    'op_after_combine6/conv/depthwise_kernel':
                    'conv5_up.depthwise_conv.weight',
                    'op_after_combine6/conv/pointwise_kernel':
                    'conv5_up.pointwise_conv.weight',
                    'op_after_combine6/conv/bias':
                    'conv5_up.pointwise_conv.bias',
                    'op_after_combine6/bn/beta':
                    'conv5_up.bn.bias',
                    'op_after_combine6/bn/gamma':
                    'conv5_up.bn.weight',
                    'op_after_combine6/bn/moving_mean':
                    'conv5_up.bn.running_mean',
                    'op_after_combine6/bn/moving_variance':
                    'conv5_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_2_6/conv2d/kernel':
                        'p5_down_channel.down_conv.weight',
                        'resample_0_2_6/conv2d/bias':
                        'p5_down_channel.down_conv.bias',
                        'resample_0_2_6/bn/beta':
                        'p5_down_channel.bn.bias',
                        'resample_0_2_6/bn/gamma':
                        'p5_down_channel.bn.weight',
                        'resample_0_2_6/bn/moving_mean':
                        'p5_down_channel.bn.running_mean',
                        'resample_0_2_6/bn/moving_variance':
                        'p5_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p5_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p5_w1[fpn_idx][1] = v
                if torch.min(p5_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p5_w1'] = p5_w1[fpn_idx]
            elif fnode_id == 2:
                # fnode 2: top-down P4 fusion (conv4_up) + first-cell
                # P4 channel reduction.
                base_mapping = {
                    'op_after_combine7/conv/depthwise_kernel':
                    'conv4_up.depthwise_conv.weight',
                    'op_after_combine7/conv/pointwise_kernel':
                    'conv4_up.pointwise_conv.weight',
                    'op_after_combine7/conv/bias':
                    'conv4_up.pointwise_conv.bias',
                    'op_after_combine7/bn/beta':
                    'conv4_up.bn.bias',
                    'op_after_combine7/bn/gamma':
                    'conv4_up.bn.weight',
                    'op_after_combine7/bn/moving_mean':
                    'conv4_up.bn.running_mean',
                    'op_after_combine7/bn/moving_variance':
                    'conv4_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_1_7/conv2d/kernel':
                        'p4_down_channel.down_conv.weight',
                        'resample_0_1_7/conv2d/bias':
                        'p4_down_channel.down_conv.bias',
                        'resample_0_1_7/bn/beta':
                        'p4_down_channel.bn.bias',
                        'resample_0_1_7/bn/gamma':
                        'p4_down_channel.bn.weight',
                        'resample_0_1_7/bn/moving_mean':
                        'p4_down_channel.bn.running_mean',
                        'resample_0_1_7/bn/moving_variance':
                        'p4_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p4_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p4_w1[fpn_idx][1] = v
                if torch.min(p4_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p4_w1'] = p4_w1[fpn_idx]
            elif fnode_id == 3:
                # fnode 3: top-down P3 fusion (conv3_up) + first-cell
                # P3 channel reduction.
                base_mapping = {
                    'op_after_combine8/conv/depthwise_kernel':
                    'conv3_up.depthwise_conv.weight',
                    'op_after_combine8/conv/pointwise_kernel':
                    'conv3_up.pointwise_conv.weight',
                    'op_after_combine8/conv/bias':
                    'conv3_up.pointwise_conv.bias',
                    'op_after_combine8/bn/beta':
                    'conv3_up.bn.bias',
                    'op_after_combine8/bn/gamma':
                    'conv3_up.bn.weight',
                    'op_after_combine8/bn/moving_mean':
                    'conv3_up.bn.running_mean',
                    'op_after_combine8/bn/moving_variance':
                    'conv3_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_0_8/conv2d/kernel':
                        'p3_down_channel.down_conv.weight',
                        'resample_0_0_8/conv2d/bias':
                        'p3_down_channel.down_conv.bias',
                        'resample_0_0_8/bn/beta':
                        'p3_down_channel.bn.bias',
                        'resample_0_0_8/bn/gamma':
                        'p3_down_channel.bn.weight',
                        'resample_0_0_8/bn/moving_mean':
                        'p3_down_channel.bn.running_mean',
                        'resample_0_0_8/bn/moving_variance':
                        'p3_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p3_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p3_w1[fpn_idx][1] = v
                if torch.min(p3_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p3_w1'] = p3_w1[fpn_idx]
            elif fnode_id == 4:
                # fnode 4: bottom-up P4 fusion (conv4_down, 3 fusion weights)
                # + first-cell P4 level connection.
                base_mapping = {
                    'op_after_combine9/conv/depthwise_kernel':
                    'conv4_down.depthwise_conv.weight',
                    'op_after_combine9/conv/pointwise_kernel':
                    'conv4_down.pointwise_conv.weight',
                    'op_after_combine9/conv/bias':
                    'conv4_down.pointwise_conv.bias',
                    'op_after_combine9/bn/beta':
                    'conv4_down.bn.bias',
                    'op_after_combine9/bn/gamma':
                    'conv4_down.bn.weight',
                    'op_after_combine9/bn/moving_mean':
                    'conv4_down.bn.running_mean',
                    'op_after_combine9/bn/moving_variance':
                    'conv4_down.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_1_9/conv2d/kernel':
                        'p4_level_connection.down_conv.weight',
                        'resample_0_1_9/conv2d/bias':
                        'p4_level_connection.down_conv.bias',
                        'resample_0_1_9/bn/beta':
                        'p4_level_connection.bn.bias',
                        'resample_0_1_9/bn/gamma':
                        'p4_level_connection.bn.weight',
                        'resample_0_1_9/bn/moving_mean':
                        'p4_level_connection.bn.running_mean',
                        'resample_0_1_9/bn/moving_variance':
                        'p4_level_connection.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1' and seg[3] != 'WSM_2':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p4_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p4_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p4_w2[fpn_idx][2] = v
                if torch.min(p4_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p4_w2'] = p4_w2[fpn_idx]
            elif fnode_id == 5:
                # fnode 5: bottom-up P5 fusion (conv5_down) + first-cell
                # P5 level connection.
                base_mapping = {
                    'op_after_combine10/conv/depthwise_kernel':
                    'conv5_down.depthwise_conv.weight',
                    'op_after_combine10/conv/pointwise_kernel':
                    'conv5_down.pointwise_conv.weight',
                    'op_after_combine10/conv/bias':
                    'conv5_down.pointwise_conv.bias',
                    'op_after_combine10/bn/beta':
                    'conv5_down.bn.bias',
                    'op_after_combine10/bn/gamma':
                    'conv5_down.bn.weight',
                    'op_after_combine10/bn/moving_mean':
                    'conv5_down.bn.running_mean',
                    'op_after_combine10/bn/moving_variance':
                    'conv5_down.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_2_10/conv2d/kernel':
                        'p5_level_connection.down_conv.weight',
                        'resample_0_2_10/conv2d/bias':
                        'p5_level_connection.down_conv.bias',
                        'resample_0_2_10/bn/beta':
                        'p5_level_connection.bn.bias',
                        'resample_0_2_10/bn/gamma':
                        'p5_level_connection.bn.weight',
                        'resample_0_2_10/bn/moving_mean':
                        'p5_level_connection.bn.running_mean',
                        'resample_0_2_10/bn/moving_variance':
                        'p5_level_connection.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1' and seg[3] != 'WSM_2':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p5_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p5_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p5_w2[fpn_idx][2] = v
                if torch.min(p5_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p5_w2'] = p5_w2[fpn_idx]
            elif fnode_id == 6:
                # fnode 6: bottom-up P6 fusion (conv6_down).
                base_mapping = {
                    'op_after_combine11/conv/depthwise_kernel':
                    'conv6_down.depthwise_conv.weight',
                    'op_after_combine11/conv/pointwise_kernel':
                    'conv6_down.pointwise_conv.weight',
                    'op_after_combine11/conv/bias':
                    'conv6_down.pointwise_conv.bias',
                    'op_after_combine11/bn/beta':
                    'conv6_down.bn.bias',
                    'op_after_combine11/bn/gamma':
                    'conv6_down.bn.weight',
                    'op_after_combine11/bn/moving_mean':
                    'conv6_down.bn.running_mean',
                    'op_after_combine11/bn/moving_variance':
                    'conv6_down.bn.running_var',
                }
                if seg[3] != 'WSM' and seg[3] != 'WSM_1' and seg[3] != 'WSM_2':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p6_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p6_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p6_w2[fpn_idx][2] = v
                if torch.min(p6_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p6_w2'] = p6_w2[fpn_idx]
            elif fnode_id == 7:
                # fnode 7: bottom-up P7 fusion (conv7_down, 2 fusion weights).
                base_mapping = {
                    'op_after_combine12/conv/depthwise_kernel':
                    'conv7_down.depthwise_conv.weight',
                    'op_after_combine12/conv/pointwise_kernel':
                    'conv7_down.pointwise_conv.weight',
                    'op_after_combine12/conv/bias':
                    'conv7_down.pointwise_conv.bias',
                    'op_after_combine12/bn/beta':
                    'conv7_down.bn.bias',
                    'op_after_combine12/bn/gamma':
                    'conv7_down.bn.weight',
                    'op_after_combine12/bn/moving_mean':
                    'conv7_down.bn.running_mean',
                    'op_after_combine12/bn/moving_variance':
                    'conv7_down.bn.running_var',
                }
                if seg[3] != 'WSM' and seg[3] != 'WSM_1' and seg[3] != 'WSM_2':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p7_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p7_w2[fpn_idx][1] = v
                if torch.min(p7_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p7_w2'] = p7_w2[fpn_idx]
        elif seg[0] == 'box_net':
            # ---- box regression head ----
            if 'box-predict' in seg[1]:
                prefix = '.'.join(['bbox_head', 'reg_header'])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
            elif 'bn' in seg[1]:
                # seg[1] looks like 'box-<conv>-bn-<level>'; per-level BN
                # indices in TF start at 3 (P3), hence the '- 3'.
                bbox_conv_idx = int(seg[1][4])
                bbox_bn_idx = int(seg[1][9]) - 3
                prefix = '.'.join([
                    'bbox_head', 'reg_bn_list',
                    str(bbox_conv_idx),
                    str(bbox_bn_idx)
                ])
                base_mapping = {
                    'beta': 'bias',
                    'gamma': 'weight',
                    'moving_mean': 'running_mean',
                    'moving_variance': 'running_var'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            else:
                bbox_conv_idx = int(seg[1][4])
                prefix = '.'.join(
                    ['bbox_head', 'reg_conv_list',
                     str(bbox_conv_idx)])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
        elif seg[0] == 'class_net':
            # ---- classification head (same layout as box_net) ----
            if 'class-predict' in seg[1]:
                prefix = '.'.join(['bbox_head', 'cls_header'])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
            elif 'bn' in seg[1]:
                # seg[1] looks like 'class-<conv>-bn-<level>'.
                cls_conv_idx = int(seg[1][6])
                cls_bn_idx = int(seg[1][11]) - 3
                prefix = '.'.join([
                    'bbox_head', 'cls_bn_list',
                    str(cls_conv_idx),
                    str(cls_bn_idx)
                ])
                base_mapping = {
                    'beta': 'bias',
                    'gamma': 'weight',
                    'moving_mean': 'running_mean',
                    'moving_variance': 'running_var'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            else:
                cls_conv_idx = int(seg[1][6])
                prefix = '.'.join(
                    ['bbox_head', 'cls_conv_list',
                     str(cls_conv_idx)])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
    return m
def parse_args():
    """Build and evaluate the command-line interface for the converter."""
    parser = argparse.ArgumentParser(
        description='convert efficientdet weight from tensorflow to pytorch')
    parser.add_argument(
        '--backbone',
        type=str,
        help='efficientnet model name, like efficientnet-b0')
    parser.add_argument(
        '--tensorflow_weight',
        type=str,
        help='efficientdet tensorflow weight name, like efficientdet-d0/model')
    parser.add_argument(
        '--out_weight',
        type=str,
        help='efficientdet pytorch weight name like demo.pth')
    return parser.parse_args()
def main():
    """Entry point: read a TF checkpoint and save it as a PyTorch .pth file."""
    args = parse_args()
    # Number of stacked BiFPN cells for each efficientdet-d{0..7} variant.
    repeat_map = {
        0: 3,
        1: 4,
        2: 5,
        3: 6,
        4: 7,
        5: 7,
        6: 8,
        7: 8,
    }
    reader = py_checkpoint_reader.NewCheckpointReader(args.tensorflow_weight)
    weights = {}
    for var_name in reader.get_variable_to_shape_map():
        weights[var_name] = torch.as_tensor(tf2pth(reader.get_tensor(var_name)))
    # Index 14 of 'efficientnet-bX' is the compound-coefficient digit X.
    bifpn_repeats = repeat_map[int(args.backbone[14])]
    state_dict = convert_key(args.backbone, bifpn_repeats, weights)
    torch.save({'state_dict': state_dict}, args.out_weight)


if __name__ == '__main__':
    main()
|