mmpose2torchserve.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import warnings
  4. from argparse import ArgumentParser, Namespace
  5. from tempfile import TemporaryDirectory
  6. import mmcv
  7. import torch
  8. from mmengine.runner import CheckpointLoader
  9. try:
  10. from model_archiver.model_packaging import package_model
  11. from model_archiver.model_packaging_utils import ModelExportUtils
  12. except ImportError:
  13. package_model = None
  14. def mmpose2torchserve(config_file: str,
  15. checkpoint_file: str,
  16. output_folder: str,
  17. model_name: str,
  18. model_version: str = '1.0',
  19. force: bool = False):
  20. """Converts MMPose model (config + checkpoint) to TorchServe `.mar`.
  21. Args:
  22. config_file:
  23. In MMPose config format.
  24. The contents vary for each task repository.
  25. checkpoint_file:
  26. In MMPose checkpoint format.
  27. The contents vary for each task repository.
  28. output_folder:
  29. Folder where `{model_name}.mar` will be created.
  30. The file created will be in TorchServe archive format.
  31. model_name:
  32. If not None, used for naming the `{model_name}.mar` file
  33. that will be created under `output_folder`.
  34. If None, `{Path(checkpoint_file).stem}` will be used.
  35. model_version:
  36. Model's version.
  37. force:
  38. If True, if there is an existing `{model_name}.mar`
  39. file under `output_folder` it will be overwritten.
  40. """
  41. mmcv.mkdir_or_exist(output_folder)
  42. config = mmcv.Config.fromfile(config_file)
  43. with TemporaryDirectory() as tmpdir:
  44. model_file = osp.join(tmpdir, 'config.py')
  45. config.dump(model_file)
  46. handler_path = osp.join(osp.dirname(__file__), 'mmpose_handler.py')
  47. model_name = model_name or osp.splitext(
  48. osp.basename(checkpoint_file))[0]
  49. # use mmcv CheckpointLoader if checkpoint is not from a local file
  50. if not osp.isfile(checkpoint_file):
  51. ckpt = CheckpointLoader.load_checkpoint(checkpoint_file)
  52. checkpoint_file = osp.join(tmpdir, 'checkpoint.pth')
  53. with open(checkpoint_file, 'wb') as f:
  54. torch.save(ckpt, f)
  55. args = Namespace(
  56. **{
  57. 'model_file': model_file,
  58. 'serialized_file': checkpoint_file,
  59. 'handler': handler_path,
  60. 'model_name': model_name,
  61. 'version': model_version,
  62. 'export_path': output_folder,
  63. 'force': force,
  64. 'requirements_file': None,
  65. 'extra_files': None,
  66. 'runtime': 'python',
  67. 'archive_format': 'default'
  68. })
  69. manifest = ModelExportUtils.generate_manifest_json(args)
  70. package_model(args, manifest)
  71. def parse_args():
  72. parser = ArgumentParser(
  73. description='Convert MMPose models to TorchServe `.mar` format.')
  74. parser.add_argument('config', type=str, help='config file path')
  75. parser.add_argument('checkpoint', type=str, help='checkpoint file path')
  76. parser.add_argument(
  77. '--output-folder',
  78. type=str,
  79. required=True,
  80. help='Folder where `{model_name}.mar` will be created.')
  81. parser.add_argument(
  82. '--model-name',
  83. type=str,
  84. default=None,
  85. help='If not None, used for naming the `{model_name}.mar`'
  86. 'file that will be created under `output_folder`.'
  87. 'If None, `{Path(checkpoint_file).stem}` will be used.')
  88. parser.add_argument(
  89. '--model-version',
  90. type=str,
  91. default='1.0',
  92. help='Number used for versioning.')
  93. parser.add_argument(
  94. '-f',
  95. '--force',
  96. action='store_true',
  97. help='overwrite the existing `{model_name}.mar`')
  98. args = parser.parse_args()
  99. return args
  100. if __name__ == '__main__':
  101. args = parse_args()
  102. # Following strings of text style are from colorama package
  103. bright_style, reset_style = '\x1b[1m', '\x1b[0m'
  104. red_text, blue_text = '\x1b[31m', '\x1b[34m'
  105. white_background = '\x1b[107m'
  106. msg = white_background + bright_style + red_text
  107. msg += 'DeprecationWarning: This tool will be deprecated in future. '
  108. msg += blue_text + 'Welcome to use the unified model deployment toolbox '
  109. msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
  110. msg += reset_style
  111. warnings.warn(msg)
  112. if package_model is None:
  113. raise ImportError('`torch-model-archiver` is required.'
  114. 'Try: pip install torch-model-archiver')
  115. mmpose2torchserve(args.config, args.checkpoint, args.output_folder,
  116. args.model_name, args.model_version, args.force)