"""Convert a TensorFlow EfficientDet checkpoint into a PyTorch state dict.

Run as a script with ``--backbone`` (e.g. ``efficientnet-b0``),
``--tensorflow_weight`` (e.g. ``efficientdet-d0/model``) and ``--out_weight``
(e.g. ``demo.pth``). TensorFlow is only imported inside :func:`main`, to read
the checkpoint; the key-mapping helpers work on a plain ``{name: Tensor}``
dict.
"""
import argparse

import numpy as np
import torch

torch.set_printoptions(precision=20)

# TF backbone block index (``blocks_<i>``) -> ``<stage>.<block>`` in the
# PyTorch backbone. NOTE(review): only 16 blocks are listed (presumably the
# efficientnet-b0 layout); other backbones will raise KeyError here.
_BLOCK_IDX2KEY = {
    0: '1.0', 1: '2.0', 2: '2.1', 3: '3.0', 4: '3.1', 5: '4.0', 6: '4.1',
    7: '4.2', 8: '4.3', 9: '4.4', 10: '4.5', 11: '5.0', 12: '5.1',
    13: '5.2', 14: '5.3', 15: '5.4',
}

_STEM_MAPPING = {
    'conv2d/kernel': 'conv.weight',
    'tpu_batch_normalization/beta': 'bn.bias',
    'tpu_batch_normalization/gamma': 'bn.weight',
    'tpu_batch_normalization/moving_mean': 'bn.running_mean',
    'tpu_batch_normalization/moving_variance': 'bn.running_var',
}

# Variables present in every backbone block (depthwise kernel and ``se``).
_BLOCK_COMMON_MAPPING = {
    'depthwise_conv2d/depthwise_kernel': 'depthwise_conv.conv.weight',
    'se/conv2d/kernel': 'se.conv1.conv.weight',
    'se/conv2d/bias': 'se.conv1.conv.bias',
    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
    'se/conv2d_1/bias': 'se.conv2.conv.bias',
}

# blocks_0 maps no ``expand_conv``, so its TF BN/conv indices shift by one
# compared with the later blocks.
_FIRST_BLOCK_MAPPING = {
    'conv2d/kernel': 'linear_conv.conv.weight',
    'tpu_batch_normalization/beta': 'depthwise_conv.bn.bias',
    'tpu_batch_normalization/gamma': 'depthwise_conv.bn.weight',
    'tpu_batch_normalization/moving_mean': 'depthwise_conv.bn.running_mean',
    'tpu_batch_normalization/moving_variance':
    'depthwise_conv.bn.running_var',
    'tpu_batch_normalization_1/beta': 'linear_conv.bn.bias',
    'tpu_batch_normalization_1/gamma': 'linear_conv.bn.weight',
    'tpu_batch_normalization_1/moving_mean': 'linear_conv.bn.running_mean',
    'tpu_batch_normalization_1/moving_variance': 'linear_conv.bn.running_var',
    **_BLOCK_COMMON_MAPPING,
}

_BLOCK_MAPPING = {
    'conv2d/kernel': 'expand_conv.conv.weight',
    'conv2d_1/kernel': 'linear_conv.conv.weight',
    'tpu_batch_normalization/beta': 'expand_conv.bn.bias',
    'tpu_batch_normalization/gamma': 'expand_conv.bn.weight',
    'tpu_batch_normalization/moving_mean': 'expand_conv.bn.running_mean',
    'tpu_batch_normalization/moving_variance': 'expand_conv.bn.running_var',
    'tpu_batch_normalization_1/beta': 'depthwise_conv.bn.bias',
    'tpu_batch_normalization_1/gamma': 'depthwise_conv.bn.weight',
    'tpu_batch_normalization_1/moving_mean':
    'depthwise_conv.bn.running_mean',
    'tpu_batch_normalization_1/moving_variance':
    'depthwise_conv.bn.running_var',
    'tpu_batch_normalization_2/beta': 'linear_conv.bn.bias',
    'tpu_batch_normalization_2/gamma': 'linear_conv.bn.weight',
    'tpu_batch_normalization_2/moving_mean': 'linear_conv.bn.running_mean',
    'tpu_batch_normalization_2/moving_variance': 'linear_conv.bn.running_var',
    **_BLOCK_COMMON_MAPPING,
}

# TF batch-norm variable -> PyTorch BatchNorm attribute.
_BN_SUFFIXES = {
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
}

# TF separable-conv variable -> PyTorch depthwise/pointwise conv attribute.
_SEPARABLE_CONV_SUFFIXES = {
    'depthwise_kernel': 'depthwise_conv.weight',
    'pointwise_kernel': 'pointwise_conv.weight',
    'bias': 'pointwise_conv.bias',
}

# Variables under a ``resample_*`` scope (1x1 conv + BN).
_RESAMPLE_SUFFIXES = {
    'conv2d/kernel': 'down_conv.weight',
    'conv2d/bias': 'down_conv.bias',
    'bn/beta': 'bn.bias',
    'bn/gamma': 'bn.weight',
    'bn/moving_mean': 'bn.running_mean',
    'bn/moving_variance': 'bn.running_var',
}

# Variables under an ``op_after_combine<N>`` scope, relative to the module.
_FNODE_CONV_SUFFIXES = {
    'conv/depthwise_kernel': 'depthwise_conv.weight',
    'conv/pointwise_kernel': 'pointwise_conv.weight',
    'conv/bias': 'pointwise_conv.bias',
    'bn/beta': 'bn.bias',
    'bn/gamma': 'bn.weight',
    'bn/moving_mean': 'bn.running_mean',
    'bn/moving_variance': 'bn.running_var',
}

# BiFPN node layout, keyed by ``fnode<j>``:
#   (N of its ``op_after_combine<N>`` scope, PyTorch conv module,
#    ``resample_*`` scope that exists only in cell 0 (or None),
#    PyTorch module that resample maps to (or None),
#    PyTorch name of the node's ``WSM*`` fusion weights, number of weights)
_FNODE_SPECS = {
    0: (5, 'conv6_up', None, None, 'p6_w1', 2),
    1: (6, 'conv5_up', 'resample_0_2_6', 'p5_down_channel', 'p5_w1', 2),
    2: (7, 'conv4_up', 'resample_0_1_7', 'p4_down_channel', 'p4_w1', 2),
    3: (8, 'conv3_up', 'resample_0_0_8', 'p3_down_channel', 'p3_w1', 2),
    4: (9, 'conv4_down', 'resample_0_1_9', 'p4_level_connection', 'p4_w2',
        3),
    5: (10, 'conv5_down', 'resample_0_2_10', 'p5_level_connection', 'p5_w2',
        3),
    6: (11, 'conv6_down', None, None, 'p6_w2', 3),
    7: (12, 'conv7_down', None, None, 'p7_w2', 2),
}

# TF fusion-weight variable name -> position in the PyTorch weight vector.
_WSM_INDEX = {'WSM': 0, 'WSM_1': 1, 'WSM_2': 2}

# TF head scope -> (PyTorch header module, conv list, BN list).
_HEAD_SPECS = {
    'box_net': ('reg_header', 'reg_conv_list', 'reg_bn_list'),
    'class_net': ('cls_header', 'cls_conv_list', 'cls_bn_list'),
}

# Backbone variant N of ``efficientnet-b<N>`` -> number of BiFPN cells.
_REPEAT_MAP = {0: 3, 1: 4, 2: 5, 3: 6, 4: 7, 5: 7, 6: 8, 7: 8}


def tf2pth(v):
    """Reorder a TF kernel array into PyTorch layout.

    4-D arrays are permuted with axes ``(3, 2, 0, 1)`` (TF conv kernels are
    presumably ``(kh, kw, in, out)``, giving PyTorch's ``(out, in, kh, kw)``);
    2-D arrays are transposed; anything else is returned unchanged. Permuted
    results are made contiguous.
    """
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v


def _transpose_if_depthwise(key, v):
    """Swap dims 0 and 1 of ``v`` when ``key`` names a depthwise conv."""
    return v.transpose(1, 0) if 'depthwise_conv' in key else v


def _backbone_key(seg):
    """Return the PyTorch key for a ``<model>/<scope>/...`` variable.

    Returns ``None`` for scopes other than ``stem`` and ``blocks_<i>``, which
    the converter skips. Raises KeyError for an unmapped variable.
    """
    tf_suffix = '/'.join(seg[2:])
    if seg[1] == 'stem':
        return 'backbone.layers.0.' + _STEM_MAPPING[tf_suffix]
    if seg[1].startswith('blocks_'):
        idx = int(seg[1][len('blocks_'):])
        mapping = _FIRST_BLOCK_MAPPING if idx == 0 else _BLOCK_MAPPING
        return '.'.join(
            ['backbone', 'layers', _BLOCK_IDX2KEY[idx], mapping[tf_suffix]])
    return None


def _fpn_cell_entry(seg, v, fusion_weights):
    """Map one ``fpn_cells/cell_<i>/fnode<j>/...`` variable.

    ``fusion_weights`` maps a fusion-weight name to one NaN-initialised
    float64 vector per cell; ``WSM*`` values are written into it in place.

    Returns:
        ``(key, tensor)`` to store, or ``None`` for a fusion weight whose
        sibling ``WSM*`` values have not all been seen yet.

    Raises:
        KeyError: for a variable of the node that has no PyTorch mapping.
    """
    fpn_idx = int(seg[1][len('cell_'):])
    fnode_id = int(seg[2][len('fnode'):])
    (combine_idx, conv_name, resample_scope, resample_target, weight_name,
     _) = _FNODE_SPECS[fnode_id]
    prefix = 'neck.bifpn.' + str(fpn_idx) + '.'

    if seg[3] in _WSM_INDEX:
        node_weights = fusion_weights[weight_name][fpn_idx]
        node_weights[_WSM_INDEX[seg[3]]] = v
        # Export only once every slot has been filled (no NaN left).
        if torch.isnan(node_weights).any():
            return None
        return prefix + weight_name, node_weights

    tf_suffix = '/'.join(seg[4:])
    if seg[3] == 'op_after_combine' + str(combine_idx):
        pt_suffix = conv_name + '.' + _FNODE_CONV_SUFFIXES[tf_suffix]
    elif fpn_idx == 0 and seg[3] == resample_scope:
        # Channel-adapting resamples only exist in the first cell.
        pt_suffix = resample_target + '.' + _RESAMPLE_SUFFIXES[tf_suffix]
    else:
        raise KeyError('/'.join(seg))
    key = prefix + pt_suffix
    return key, _transpose_if_depthwise(key, v)


def _head_entry(seg, v):
    """Map one ``box_net/...`` or ``class_net/...`` variable to (key, tensor).

    ``seg[1]`` is ``<head>-predict``, ``<head>-<conv>`` or
    ``<head>-<conv>-bn-<n>``. Raises KeyError for an unmapped variable.
    """
    header, conv_list, bn_list = _HEAD_SPECS[seg[0]]
    tf_suffix = '/'.join(seg[2:])
    parts = seg[1].split('-')
    if 'predict' in parts:
        key = '.'.join(
            ['bbox_head', header, _SEPARABLE_CONV_SUFFIXES[tf_suffix]])
    elif 'bn' in parts:
        # TF numbers these BN layers from 3 (presumably the P3 level); the
        # PyTorch list is 0-based.
        bn_idx = int(parts[3]) - 3
        key = '.'.join([
            'bbox_head', bn_list,
            str(int(parts[1])),
            str(bn_idx), _BN_SUFFIXES[tf_suffix]
        ])
    else:
        key = '.'.join([
            'bbox_head', conv_list,
            str(int(parts[1])), _SEPARABLE_CONV_SUFFIXES[tf_suffix]
        ])
    return key, _transpose_if_depthwise(key, v)


def convert_key(model_name, bifpn_repeats, weights):
    """Rename TF EfficientDet variables to the PyTorch model's state-dict keys.

    Args:
        model_name (str): Backbone scope name in the checkpoint, e.g.
            ``'efficientnet-b0'``.
        bifpn_repeats (int): Number of BiFPN cells (``fpn_cells/cell_<i>``).
        weights (dict[str, torch.Tensor]): TF variable name -> tensor. In
            :func:`main` the kernels have already been reordered by
            :func:`tf2pth`.

    Returns:
        dict[str, torch.Tensor]: PyTorch state-dict entries. Fusion weights
        (``p*_w1`` / ``p*_w2``) are float64 vectors, emitted only once all of
        a node's ``WSM*`` inputs were present; incomplete ones are omitted.
        Variables under unrecognised top-level scopes are ignored.

    Raises:
        KeyError: for a variable inside a recognised scope with no mapping.
        IndexError: for a fusion weight of ``cell_<i>`` with
            ``i >= bifpn_repeats``.
    """
    # NaN marks "not seen yet" so that any real weight value can be exported.
    fusion_weights = {
        name: [
            torch.full((size, ), float('nan'), dtype=torch.float64)
            for _ in range(bifpn_repeats)
        ]
        for _, _, _, _, name, size in _FNODE_SPECS.values()
    }
    m = {}
    for k, v in weights.items():
        # Skip 'Exponential*' variables (presumably moving-average shadows)
        # and the step counter.
        if 'Exponential' in k or 'global_step' in k:
            continue
        seg = k.split('/')
        # Every mapped variable is at least <scope>/<sub-scope>/<name>;
        # shorter names carry no model weights here.
        if len(seg) < 3:
            continue
        if seg[2] == 'depthwise_conv2d':
            v = v.transpose(1, 0)

        if seg[0] == model_name:
            key = _backbone_key(seg)
            if key is not None:
                m[key] = v
        elif seg[0] == 'resample_p6':
            suffix = _RESAMPLE_SUFFIXES['/'.join(seg[1:])]
            m['neck.bifpn.0.p5_to_p6.0.' + suffix] = v
        elif seg[0] == 'fpn_cells':
            entry = _fpn_cell_entry(seg, v, fusion_weights)
            if entry is not None:
                key, value = entry
                m[key] = value
        elif seg[0] in _HEAD_SPECS:
            key, value = _head_entry(seg, v)
            m[key] = value
    return m


def _bifpn_repeats(model_name):
    """Return the number of BiFPN cells for an ``efficientnet-b<N>`` name.

    Raises:
        ValueError: if ``model_name`` is not ``efficientnet-b<N>`` with a
            known ``N``.
    """
    prefix = 'efficientnet-b'
    variant = model_name[len(prefix):] if model_name.startswith(prefix) else ''
    if variant.isdecimal() and int(variant) in _REPEAT_MAP:
        return _REPEAT_MAP[int(variant)]
    raise ValueError('unsupported backbone {!r}; expected efficientnet-b0 '
                     'to efficientnet-b7'.format(model_name))


def parse_args():
    """Parse the converter's command-line options (all of them required)."""
    parser = argparse.ArgumentParser(
        description='convert efficientdet weight from tensorflow to pytorch')
    parser.add_argument(
        '--backbone',
        type=str,
        required=True,
        help='efficientnet model name, like efficientnet-b0')
    parser.add_argument(
        '--tensorflow_weight',
        type=str,
        required=True,
        help='efficientdet tensorflow weight name, like efficientdet-d0/model')
    parser.add_argument(
        '--out_weight',
        type=str,
        required=True,
        help='efficientdet pytorch weight name like demo.pth')
    return parser.parse_args()


def main():
    """Read the TF checkpoint, convert its keys and save a PyTorch file."""
    args = parse_args()
    # Validate the backbone before touching the checkpoint.
    bifpn_repeats = _bifpn_repeats(args.backbone)
    # Imported here so that ``--help`` and the conversion helpers above do
    # not require TensorFlow to be installed.
    from tensorflow.python.training import py_checkpoint_reader

    reader = py_checkpoint_reader.NewCheckpointReader(args.tensorflow_weight)
    weights = {
        name: torch.as_tensor(tf2pth(reader.get_tensor(name)))
        for name in reader.get_variable_to_shape_map()
    }
    out = convert_key(args.backbone, bifpn_repeats, weights)
    torch.save({'state_dict': out}, args.out_weight)


if __name__ == '__main__':
    main()