# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from argparse import ArgumentParser, Namespace
from tempfile import TemporaryDirectory
from typing import Optional

import mmcv  # noqa: F401
import torch
from mmengine.config import Config
from mmengine.runner import CheckpointLoader
from mmengine.utils import mkdir_or_exist

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    # Optional dependency; checked explicitly before use.
    package_model = None


def mmpose2torchserve(config_file: str,
                      checkpoint_file: str,
                      output_folder: str,
                      model_name: Optional[str],
                      model_version: str = '1.0',
                      force: bool = False):
    """Converts MMPose model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file: In MMPose config format.
            The contents vary for each task repository.
        checkpoint_file: In MMPose checkpoint format.
            The contents vary for each task repository. If it is not a
            local file, it is fetched with `CheckpointLoader` and re-saved
            into a temporary directory before packaging.
        output_folder: Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
            The folder is created if it does not exist.
        model_name: If not None, used for naming the `{model_name}.mar`
            file that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version: Model's version.
        force: If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.

    Raises:
        ImportError: If `torch-model-archiver` is not installed.
    """
    # Fail before any filesystem side effects if the archiver is missing;
    # otherwise `package_model(...)` below would fail with a TypeError.
    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mkdir_or_exist(output_folder)

    config = Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # The archiver expects a single model file; dump the (possibly
        # inherited) config into one flat python file.
        model_file = osp.join(tmpdir, 'config.py')
        config.dump(model_file)
        handler_path = osp.join(osp.dirname(__file__), 'mmpose_handler.py')
        model_name = model_name or osp.splitext(
            osp.basename(checkpoint_file))[0]

        # use mmengine CheckpointLoader if checkpoint is not from a local
        # file, and materialize it on disk so the archiver can read it
        if not osp.isfile(checkpoint_file):
            ckpt = CheckpointLoader.load_checkpoint(checkpoint_file)
            checkpoint_file = osp.join(tmpdir, 'checkpoint.pth')
            with open(checkpoint_file, 'wb') as f:
                torch.save(ckpt, f)

        # Mirror the CLI arguments that `torch-model-archiver` would parse.
        args = Namespace(
            **{
                'model_file': model_file,
                'serialized_file': checkpoint_file,
                'handler': handler_path,
                'model_name': model_name,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args() -> Namespace:
    """Parse command-line arguments for the conversion tool."""
    parser = ArgumentParser(
        description='Convert MMPose models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar` '
        'file that will be created under `output_folder`. '
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    if package_model is None:
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmpose2torchserve(args.config, args.checkpoint, args.output_folder,
                      args.model_name, args.model_version, args.force)