_wsgi.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import json
  4. import logging
  5. import logging.config
  6. import os
  7. logging.config.dictConfig({
  8. 'version': 1,
  9. 'formatters': {
  10. 'standard': {
  11. 'format':
  12. '[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s' # noqa E501
  13. }
  14. },
  15. 'handlers': {
  16. 'console': {
  17. 'class': 'logging.StreamHandler',
  18. 'level': 'DEBUG',
  19. 'stream': 'ext://sys.stdout',
  20. 'formatter': 'standard'
  21. }
  22. },
  23. 'root': {
  24. 'level': 'ERROR',
  25. 'handlers': ['console'],
  26. 'propagate': True
  27. }
  28. })
  29. _DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
  30. def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
  31. if not os.path.exists(config_path):
  32. return dict()
  33. with open(config_path) as f:
  34. config = json.load(f)
  35. assert isinstance(config, dict)
  36. return config
  37. if __name__ == '__main__':
  38. from label_studio_ml.api import init_app
  39. from projects.LabelStudio.backend_template.mmdetection import MMDetection
  40. parser = argparse.ArgumentParser(description='Label studio')
  41. parser.add_argument(
  42. '-p',
  43. '--port',
  44. dest='port',
  45. type=int,
  46. default=9090,
  47. help='Server port')
  48. parser.add_argument(
  49. '--host', dest='host', type=str, default='0.0.0.0', help='Server host')
  50. parser.add_argument(
  51. '--kwargs',
  52. '--with',
  53. dest='kwargs',
  54. metavar='KEY=VAL',
  55. nargs='+',
  56. type=lambda kv: kv.split('='),
  57. help='Additional LabelStudioMLBase model initialization kwargs')
  58. parser.add_argument(
  59. '-d',
  60. '--debug',
  61. dest='debug',
  62. action='store_true',
  63. help='Switch debug mode')
  64. parser.add_argument(
  65. '--log-level',
  66. dest='log_level',
  67. choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
  68. default=None,
  69. help='Logging level')
  70. parser.add_argument(
  71. '--model-dir',
  72. dest='model_dir',
  73. default=os.path.dirname(__file__),
  74. help='Directory models are store',
  75. )
  76. parser.add_argument(
  77. '--check',
  78. dest='check',
  79. action='store_true',
  80. help='Validate model instance before launching server')
  81. args = parser.parse_args()
  82. # setup logging level
  83. if args.log_level:
  84. logging.root.setLevel(args.log_level)
  85. def isfloat(value):
  86. try:
  87. float(value)
  88. return True
  89. except ValueError:
  90. return False
  91. def parse_kwargs():
  92. param = dict()
  93. for k, v in args.kwargs:
  94. if v.isdigit():
  95. param[k] = int(v)
  96. elif v == 'True' or v == 'true':
  97. param[k] = True
  98. elif v == 'False' or v == 'False':
  99. param[k] = False
  100. elif isfloat(v):
  101. param[k] = float(v)
  102. else:
  103. param[k] = v
  104. return param
  105. kwargs = get_kwargs_from_config()
  106. if args.kwargs:
  107. kwargs.update(parse_kwargs())
  108. if args.check:
  109. print('Check "' + MMDetection.__name__ + '" instance creation..')
  110. model = MMDetection(**kwargs)
  111. app = init_app(
  112. model_class=MMDetection,
  113. model_dir=os.environ.get('MODEL_DIR', args.model_dir),
  114. redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
  115. redis_host=os.environ.get('REDIS_HOST', 'localhost'),
  116. redis_port=os.environ.get('REDIS_PORT', 6379),
  117. **kwargs)
  118. app.run(host=args.host, port=args.port, debug=args.debug)
  119. else:
  120. # for uWSGI use
  121. app = init_app(
  122. model_class=MMDetection,
  123. model_dir=os.environ.get('MODEL_DIR', os.path.dirname(__file__)),
  124. redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'),
  125. redis_host=os.environ.get('REDIS_HOST', 'localhost'),
  126. redis_port=os.environ.get('REDIS_PORT', 6379))