mmdet2torchserve.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

from mmengine.config import Config
from mmengine.utils import mkdir_or_exist

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None
  12. def mmdet2torchserve(
  13. config_file: str,
  14. checkpoint_file: str,
  15. output_folder: str,
  16. model_name: str,
  17. model_version: str = '1.0',
  18. force: bool = False,
  19. ):
  20. """Converts MMDetection model (config + checkpoint) to TorchServe `.mar`.
  21. Args:
  22. config_file:
  23. In MMDetection config format.
  24. The contents vary for each task repository.
  25. checkpoint_file:
  26. In MMDetection checkpoint format.
  27. The contents vary for each task repository.
  28. output_folder:
  29. Folder where `{model_name}.mar` will be created.
  30. The file created will be in TorchServe archive format.
  31. model_name:
  32. If not None, used for naming the `{model_name}.mar` file
  33. that will be created under `output_folder`.
  34. If None, `{Path(checkpoint_file).stem}` will be used.
  35. model_version:
  36. Model's version.
  37. force:
  38. If True, if there is an existing `{model_name}.mar`
  39. file under `output_folder` it will be overwritten.
  40. """
  41. mkdir_or_exist(output_folder)
  42. config = Config.fromfile(config_file)
  43. with TemporaryDirectory() as tmpdir:
  44. config.dump(f'{tmpdir}/config.py')
  45. args = Namespace(
  46. **{
  47. 'model_file': f'{tmpdir}/config.py',
  48. 'serialized_file': checkpoint_file,
  49. 'handler': f'{Path(__file__).parent}/mmdet_handler.py',
  50. 'model_name': model_name or Path(checkpoint_file).stem,
  51. 'version': model_version,
  52. 'export_path': output_folder,
  53. 'force': force,
  54. 'requirements_file': None,
  55. 'extra_files': None,
  56. 'runtime': 'python',
  57. 'archive_format': 'default'
  58. })
  59. manifest = ModelExportUtils.generate_manifest_json(args)
  60. package_model(args, manifest)
  61. def parse_args():
  62. parser = ArgumentParser(
  63. description='Convert MMDetection models to TorchServe `.mar` format.')
  64. parser.add_argument('config', type=str, help='config file path')
  65. parser.add_argument('checkpoint', type=str, help='checkpoint file path')
  66. parser.add_argument(
  67. '--output-folder',
  68. type=str,
  69. required=True,
  70. help='Folder where `{model_name}.mar` will be created.')
  71. parser.add_argument(
  72. '--model-name',
  73. type=str,
  74. default=None,
  75. help='If not None, used for naming the `{model_name}.mar`'
  76. 'file that will be created under `output_folder`.'
  77. 'If None, `{Path(checkpoint_file).stem}` will be used.')
  78. parser.add_argument(
  79. '--model-version',
  80. type=str,
  81. default='1.0',
  82. help='Number used for versioning.')
  83. parser.add_argument(
  84. '-f',
  85. '--force',
  86. action='store_true',
  87. help='overwrite the existing `{model_name}.mar`')
  88. args = parser.parse_args()
  89. return args
  90. if __name__ == '__main__':
  91. args = parse_args()
  92. if package_model is None:
  93. raise ImportError('`torch-model-archiver` is required.'
  94. 'Try: pip install torch-model-archiver')
  95. mmdet2torchserve(args.config, args.checkpoint, args.output_folder,
  96. args.model_name, args.model_version, args.force)