inferencer_demo.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from argparse import ArgumentParser
  3. from typing import Dict
  4. from mmpose.apis.inferencers import MMPoseInferencer, get_model_aliases
  5. def parse_args():
  6. parser = ArgumentParser()
  7. parser.add_argument(
  8. 'inputs',
  9. type=str,
  10. nargs='?',
  11. help='Input image/video path or folder path.')
  12. parser.add_argument(
  13. '--pose2d',
  14. type=str,
  15. default=None,
  16. help='Pretrained 2D pose estimation algorithm. It\'s the path to the '
  17. 'config file or the model name defined in metafile.')
  18. parser.add_argument(
  19. '--pose2d-weights',
  20. type=str,
  21. default=None,
  22. help='Path to the custom checkpoint file of the selected pose model. '
  23. 'If it is not specified and "pose2d" is a model name of metafile, '
  24. 'the weights will be loaded from metafile.')
  25. parser.add_argument(
  26. '--det-model',
  27. type=str,
  28. default=None,
  29. help='Config path or alias of detection model.')
  30. parser.add_argument(
  31. '--det-weights',
  32. type=str,
  33. default=None,
  34. help='Path to the checkpoints of detection model.')
  35. parser.add_argument(
  36. '--det-cat-ids',
  37. type=int,
  38. nargs='+',
  39. default=0,
  40. help='Category id for detection model.')
  41. parser.add_argument(
  42. '--scope',
  43. type=str,
  44. default='mmpose',
  45. help='Scope where modules are defined.')
  46. parser.add_argument(
  47. '--device',
  48. type=str,
  49. default=None,
  50. help='Device used for inference. '
  51. 'If not specified, the available device will be automatically used.')
  52. parser.add_argument(
  53. '--show',
  54. action='store_true',
  55. help='Display the image/video in a popup window.')
  56. parser.add_argument(
  57. '--draw-bbox',
  58. action='store_true',
  59. help='Whether to draw the bounding boxes.')
  60. parser.add_argument(
  61. '--draw-heatmap',
  62. action='store_true',
  63. default=False,
  64. help='Whether to draw the predicted heatmaps.')
  65. parser.add_argument(
  66. '--bbox-thr',
  67. type=float,
  68. default=0.3,
  69. help='Bounding box score threshold')
  70. parser.add_argument(
  71. '--nms-thr',
  72. type=float,
  73. default=0.3,
  74. help='IoU threshold for bounding box NMS')
  75. parser.add_argument(
  76. '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
  77. parser.add_argument(
  78. '--radius',
  79. type=int,
  80. default=3,
  81. help='Keypoint radius for visualization.')
  82. parser.add_argument(
  83. '--thickness',
  84. type=int,
  85. default=1,
  86. help='Link thickness for visualization.')
  87. parser.add_argument(
  88. '--vis-out-dir',
  89. type=str,
  90. default='',
  91. help='Directory for saving visualized results.')
  92. parser.add_argument(
  93. '--pred-out-dir',
  94. type=str,
  95. default='',
  96. help='Directory for saving inference results.')
  97. parser.add_argument(
  98. '--show-alias',
  99. action='store_true',
  100. help='Display all the available model aliases.')
  101. call_args = vars(parser.parse_args())
  102. init_kws = [
  103. 'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
  104. 'det_weights', 'det_cat_ids'
  105. ]
  106. init_args = {}
  107. init_args['output_heatmaps'] = call_args.pop('draw_heatmap')
  108. for init_kw in init_kws:
  109. init_args[init_kw] = call_args.pop(init_kw)
  110. diaplay_alias = call_args.pop('show_alias')
  111. return init_args, call_args, diaplay_alias
  112. def display_model_aliases(model_aliases: Dict[str, str]) -> None:
  113. """Display the available model aliases and their corresponding model
  114. names."""
  115. aliases = list(model_aliases.keys())
  116. max_alias_length = max(map(len, aliases))
  117. print(f'{"ALIAS".ljust(max_alias_length+2)}MODEL_NAME')
  118. for alias in sorted(aliases):
  119. print(f'{alias.ljust(max_alias_length+2)}{model_aliases[alias]}')
  120. def main():
  121. init_args, call_args, diaplay_alias = parse_args()
  122. if diaplay_alias:
  123. model_alises = get_model_aliases(init_args['scope'])
  124. display_model_aliases(model_alises)
  125. else:
  126. inferencer = MMPoseInferencer(**init_args)
  127. for _ in inferencer(**call_args):
  128. pass
  129. if __name__ == '__main__':
  130. main()