123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from argparse import ArgumentParser
- from typing import Dict
- from mmpose.apis.inferencers import MMPoseInferencer, get_model_aliases
def _build_parser() -> ArgumentParser:
    """Create the command-line parser for the MMPose inferencer demo.

    Returns:
        ArgumentParser: Parser with every option the demo understands.
    """
    parser = ArgumentParser()
    # Optional at the argparse level only so that `--show-alias` can be used
    # on its own; parse_args() enforces it for every other invocation.
    parser.add_argument(
        'inputs',
        type=str,
        nargs='?',
        help='Input image/video path or folder path.')
    parser.add_argument(
        '--pose2d',
        type=str,
        default=None,
        help='Pretrained 2D pose estimation algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--pose2d-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected pose model. '
        'If it is not specified and "pose2d" is a model name of metafile, '
        'the weights will be loaded from metafile.')
    parser.add_argument(
        '--det-model',
        type=str,
        default=None,
        help='Config path or alias of detection model.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the checkpoints of detection model.')
    # NOTE(review): the default is the int 0, while ids given on the command
    # line arrive as a list of ints (nargs='+'); presumably the inferencer
    # accepts both forms -- confirm before normalizing the default.
    parser.add_argument(
        '--det-cat-ids',
        type=int,
        nargs='+',
        default=0,
        help='Category id for detection model.')
    parser.add_argument(
        '--scope',
        type=str,
        default='mmpose',
        help='Scope where modules are defined.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
        'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image/video in a popup window.')
    parser.add_argument(
        '--draw-bbox',
        action='store_true',
        help='Whether to draw the bounding boxes.')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        help='Whether to draw the predicted heatmaps.')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=0.3,
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization.')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization.')
    parser.add_argument(
        '--vis-out-dir',
        type=str,
        default='',
        help='Directory for saving visualized results.')
    parser.add_argument(
        '--pred-out-dir',
        type=str,
        default='',
        help='Directory for saving inference results.')
    parser.add_argument(
        '--show-alias',
        action='store_true',
        help='Display all the available model aliases.')
    return parser


def parse_args():
    """Parse the command line and split the options by their consumer.

    Returns:
        tuple: ``(init_args, call_args, display_alias)`` where

        - ``init_args`` (dict) holds the ``MMPoseInferencer`` constructor
          keywords: ``output_heatmaps`` (from ``--draw-heatmap``),
          ``pose2d``, ``pose2d_weights``, ``scope``, ``device``,
          ``det_model``, ``det_weights`` and ``det_cat_ids``.
        - ``call_args`` (dict) holds every remaining option, including
          ``inputs``, meant to be passed to the inferencer call.
        - ``display_alias`` (bool) is True when ``--show-alias`` was given.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``inputs`` is
            omitted and ``--show-alias`` is not set, because there is
            nothing to run inference on.
    """
    parser = _build_parser()
    call_args = vars(parser.parse_args())

    display_alias = call_args.pop('show_alias')
    if call_args['inputs'] is None and not display_alias:
        # Fail here with a usage message instead of handing `None` to the
        # inferencer, which would only fail later with an obscure error.
        parser.error('the following arguments are required: inputs '
                     '(unless --show-alias is given)')

    init_kws = ('pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
                'det_weights', 'det_cat_ids')
    # The constructor calls the heatmap switch `output_heatmaps`.
    init_args = {'output_heatmaps': call_args.pop('draw_heatmap')}
    init_args.update({kw: call_args.pop(kw) for kw in init_kws})

    return init_args, call_args, display_alias
def display_model_aliases(model_aliases: Dict[str, str]) -> None:
    """Display the available model aliases and their corresponding model
    names as a two-column table.

    Aliases are printed in sorted order below an ``ALIAS`` / ``MODEL_NAME``
    header. The alias column is as wide as the longer of the header and the
    longest alias, plus two spaces of separation.

    Args:
        model_aliases (Dict[str, str]): Mapping from alias to model name.
            May be empty, in which case only the header is printed.
    """
    header = 'ALIAS'
    # Include the header in the width so it stays aligned when every alias is
    # shorter than 'ALIAS'; `default=0` keeps an empty mapping from raising
    # ValueError in max().
    longest_alias = max(map(len, model_aliases), default=0)
    column_width = max(len(header), longest_alias) + 2
    print(f'{header.ljust(column_width)}MODEL_NAME')
    for alias in sorted(model_aliases):
        print(f'{alias.ljust(column_width)}{model_aliases[alias]}')
def main():
    """Run the demo: either list model aliases or perform pose inference.

    With ``--show-alias`` the aliases returned by ``get_model_aliases`` for
    the selected scope are printed and nothing else happens. Otherwise an
    ``MMPoseInferencer`` is built from the constructor options and called
    with the remaining command-line options.
    """
    init_args, call_args, wants_alias_listing = parse_args()

    if wants_alias_listing:
        display_model_aliases(get_model_aliases(init_args['scope']))
        return

    inferencer = MMPoseInferencer(**init_args)
    # The call's result is iterated to exhaustion and every item discarded;
    # presumably the inferencer yields results lazily, so this loop is what
    # drives inference and any visualization/saving -- confirm upstream.
    for _unused_result in inferencer(**call_args):
        pass
# Run the demo only when this file is executed as a script; importing the
# module just defines the functions above.
if __name__ == '__main__':
    main()
|