diffusiondet_resnet_to_mmdet.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from collections import OrderedDict
  4. import numpy as np
  5. import torch
  6. from mmengine.fileio import load
  7. def convert(src, dst):
  8. if src.endswith('pth'):
  9. src_model = torch.load(src)
  10. else:
  11. src_model = load(src)
  12. dst_state_dict = OrderedDict()
  13. for k, v in src_model['model'].items():
  14. key_name_split = k.split('.')
  15. if 'backbone.fpn_lateral' in k:
  16. lateral_id = int(key_name_split[-2][-1])
  17. name = f'neck.lateral_convs.{lateral_id - 2}.' \
  18. f'conv.{key_name_split[-1]}'
  19. elif 'backbone.fpn_output' in k:
  20. lateral_id = int(key_name_split[-2][-1])
  21. name = f'neck.fpn_convs.{lateral_id - 2}.conv.' \
  22. f'{key_name_split[-1]}'
  23. elif 'backbone.bottom_up.stem.conv1.norm.' in k:
  24. name = f'backbone.bn1.{key_name_split[-1]}'
  25. elif 'backbone.bottom_up.stem.conv1.' in k:
  26. name = f'backbone.conv1.{key_name_split[-1]}'
  27. elif 'backbone.bottom_up.res' in k:
  28. # weight_type = key_name_split[-1]
  29. res_id = int(key_name_split[2][-1]) - 1
  30. # deal with short cut
  31. if 'shortcut' in key_name_split[4]:
  32. if 'shortcut' == key_name_split[-2]:
  33. name = f'backbone.layer{res_id}.' \
  34. f'{key_name_split[3]}.downsample.0.' \
  35. f'{key_name_split[-1]}'
  36. elif 'shortcut' == key_name_split[-3]:
  37. name = f'backbone.layer{res_id}.' \
  38. f'{key_name_split[3]}.downsample.1.' \
  39. f'{key_name_split[-1]}'
  40. else:
  41. print(f'Unvalid key {k}')
  42. # deal with conv
  43. elif 'conv' in key_name_split[-2]:
  44. conv_id = int(key_name_split[-2][-1])
  45. name = f'backbone.layer{res_id}.{key_name_split[3]}' \
  46. f'.conv{conv_id}.{key_name_split[-1]}'
  47. # deal with BN
  48. elif key_name_split[-2] == 'norm':
  49. conv_id = int(key_name_split[-3][-1])
  50. name = f'backbone.layer{res_id}.{key_name_split[3]}.' \
  51. f'bn{conv_id}.{key_name_split[-1]}'
  52. else:
  53. print(f'{k} is invalid')
  54. elif key_name_split[0] == 'head':
  55. # d2: head.xxx -> mmdet: bbox_head.xxx
  56. name = f'bbox_{k}'
  57. else:
  58. # some base parameters such as beta will not convert
  59. print(f'{k} is not converted!!')
  60. continue
  61. if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
  62. raise ValueError(
  63. 'Unsupported type found in checkpoint! {}: {}'.format(
  64. k, type(v)))
  65. if not isinstance(v, torch.Tensor):
  66. dst_state_dict[name] = torch.from_numpy(v)
  67. else:
  68. dst_state_dict[name] = v
  69. mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
  70. torch.save(mmdet_model, dst)
  71. def main():
  72. parser = argparse.ArgumentParser(description='Convert model keys')
  73. parser.add_argument('src', help='src detectron model path')
  74. parser.add_argument('dst', help='save path')
  75. args = parser.parse_args()
  76. convert(args.src, args.dst)
  77. if __name__ == '__main__':
  78. main()