# upgrade_model_version.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
import tempfile
from collections import OrderedDict

import torch
from mmengine import Config


def is_head(key):
    valid_head_list = [
        'bbox_head', 'mask_head', 'semantic_head', 'grid_head', 'mask_iou_head'
    ]
    return any(key.startswith(h) for h in valid_head_list)


def parse_config(config_strings):
    temp_file = tempfile.NamedTemporaryFile()
    config_path = f'{temp_file.name}.py'
    with open(config_path, 'w') as f:
        f.write(config_strings)

    config = Config.fromfile(config_path)
    is_two_stage = True
    is_ssd = False
    is_retina = False
    reg_cls_agnostic = False
    if 'rpn_head' not in config.model:
        is_two_stage = False
        # check whether it is SSD
        if config.model.bbox_head.type == 'SSDHead':
            is_ssd = True
        elif config.model.bbox_head.type == 'RetinaHead':
            is_retina = True
    elif isinstance(config.model['bbox_head'], list):
        reg_cls_agnostic = True
    elif 'reg_class_agnostic' in config.model.bbox_head:
        reg_cls_agnostic = config.model.bbox_head \
            .reg_class_agnostic
    temp_file.close()
    return is_two_stage, is_ssd, is_retina, reg_cls_agnostic
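
# Illustrative (hypothetical) return values of parse_config, following the
# branches above:
#   a two-stage config with an `rpn_head` and a single class-specific
#   bbox_head                                   -> (True, False, False, False)
#   a single-stage config whose bbox_head type
#   is 'RetinaHead'                             -> (False, False, True, False)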

def reorder_cls_channel(val, num_classes=81):
    # bias
    if val.dim() == 1:
        new_val = torch.cat((val[1:], val[:1]), dim=0)
    # weight
    else:
        out_channels, in_channels = val.shape[:2]
        # conv_cls for softmax output
        if out_channels != num_classes and out_channels % num_classes == 0:
            new_val = val.reshape(-1, num_classes, in_channels,
                                  *val.shape[2:])
            new_val = torch.cat((new_val[:, 1:], new_val[:, :1]), dim=1)
            new_val = new_val.reshape(val.size())
        # fc_cls
        elif out_channels == num_classes:
            new_val = torch.cat((val[1:], val[:1]), dim=0)
        # agnostic | retina_cls | rpn_cls
        else:
            new_val = val

    return new_val
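
# A minimal sketch (hypothetical values) of the reorder for a 1-D bias:
# the channel at index 0 is moved to the end, e.g.
#   reorder_cls_channel(torch.tensor([9., 0., 1., 2.]), num_classes=4)
#   -> tensor([0., 1., 2., 9.])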

def truncate_cls_channel(val, num_classes=81):
    # bias
    if val.dim() == 1:
        if val.size(0) % num_classes == 0:
            new_val = val[:num_classes - 1]
        else:
            new_val = val
    # weight
    else:
        out_channels, in_channels = val.shape[:2]
        # conv_logits
        if out_channels % num_classes == 0:
            new_val = val.reshape(num_classes, in_channels,
                                  *val.shape[2:])[1:]
            new_val = new_val.reshape(-1, *val.shape[1:])
        # agnostic
        else:
            new_val = val

    return new_val
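
# Shape sketch (hypothetical sizes): a conv_logits weight of shape
# (81, 256, 1, 1) with num_classes=81 loses one per-class group and
# becomes (80, 256, 1, 1); a matching 81-element bias becomes 80 elements.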

def truncate_reg_channel(val, num_classes=81):
    # bias
    if val.dim() == 1:
        # fc_reg | rpn_reg
        if val.size(0) % num_classes == 0:
            new_val = val.reshape(num_classes, -1)[:num_classes - 1]
            new_val = new_val.reshape(-1)
        # agnostic
        else:
            new_val = val
    # weight
    else:
        out_channels, in_channels = val.shape[:2]
        # fc_reg | rpn_reg
        if out_channels % num_classes == 0:
            new_val = val.reshape(num_classes, -1, in_channels,
                                  *val.shape[2:])[1:]
            new_val = new_val.reshape(-1, *val.shape[1:])
        # agnostic
        else:
            new_val = val

    return new_val
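
# Shape sketch (hypothetical sizes): a class-specific fc_reg weight of
# shape (324, 1024) with num_classes=81 is viewed as 81 groups of 4
# regression channels; dropping one group gives (320, 1024), and the
# 324-element bias becomes 320 elements.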

def convert(in_file, out_file, num_classes):
    """Convert keys in checkpoints.

    There can be some breaking changes during the development of mmdetection,
    and this tool is used for upgrading checkpoints trained with old versions
    to the latest one.
    """
    checkpoint = torch.load(in_file)
    in_state_dict = checkpoint.pop('state_dict')
    out_state_dict = OrderedDict()
    meta_info = checkpoint['meta']
    is_two_stage, is_ssd, is_retina, reg_cls_agnostic = parse_config(
        '#' + meta_info['config'])
    if meta_info['mmdet_version'] <= '0.5.3' and is_retina:
        upgrade_retina = True
    else:
        upgrade_retina = False

    # MMDetection v2.5.0 unifies the class order in RPN.
    # If the model was trained with a version < v2.5.0,
    # the RPN weights should be upgraded to be used in versions >= v2.5.0.
    if meta_info['mmdet_version'] < '2.5.0':
        upgrade_rpn = True
    else:
        upgrade_rpn = False

    for key, val in in_state_dict.items():
        new_key = key
        new_val = val
        if is_two_stage and is_head(key):
            new_key = 'roi_head.{}'.format(key)

        # classification
        if upgrade_rpn:
            m = re.search(
                r'(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|'
                r'fovea_cls).(weight|bias)', new_key)
        else:
            m = re.search(
                r'(conv_cls|retina_cls|fc_cls|fcos_cls|'
                r'fovea_cls).(weight|bias)', new_key)
        if m is not None:
            print(f'reorder cls channels of {new_key}')
            new_val = reorder_cls_channel(val, num_classes)

        # regression
        if upgrade_rpn:
            m = re.search(r'(fc_reg).(weight|bias)', new_key)
        else:
            m = re.search(r'(fc_reg|rpn_reg).(weight|bias)', new_key)
        if m is not None and not reg_cls_agnostic:
            print(f'truncate regression channels of {new_key}')
            new_val = truncate_reg_channel(val, num_classes)

        # mask head
        m = re.search(r'(conv_logits).(weight|bias)', new_key)
        if m is not None:
            print(f'truncate mask prediction channels of {new_key}')
            new_val = truncate_cls_channel(val, num_classes)

        m = re.search(r'(cls_convs|reg_convs).\d.(weight|bias)', key)
        # Legacy issue in RetinaNet since v1.x:
        # ConvModule is used instead of nn.Conv2d in RetinaNet, so
        # cls_convs.0.weight -> cls_convs.0.conv.weight
        if m is not None and upgrade_retina:
            param = m.groups()[1]
            new_key = key.replace(param, f'conv.{param}')
            out_state_dict[new_key] = val
            print(f'rename the name of {key} to {new_key}')
            continue

        m = re.search(r'(cls_convs).\d.(weight|bias)', key)
        if m is not None and is_ssd:
            print(f'reorder cls channels of {new_key}')
            new_val = reorder_cls_channel(val, num_classes)

        out_state_dict[new_key] = new_val
    checkpoint['state_dict'] = out_state_dict
    torch.save(checkpoint, out_file)

def main():
    parser = argparse.ArgumentParser(description='Upgrade model version')
    parser.add_argument('in_file', help='input checkpoint file')
    parser.add_argument('out_file', help='output checkpoint file')
    parser.add_argument(
        '--num-classes',
        type=int,
        default=81,
        help='number of classes of the original model')
    args = parser.parse_args()
    convert(args.in_file, args.out_file, args.num_classes)


if __name__ == '__main__':
    main()
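
# Example invocation (hypothetical file names; the flags follow the argparse
# definitions above):
#   python upgrade_model_version.py old_checkpoint.pth new_checkpoint.pth \
#       --num-classes 81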