test_torchserver.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import os
  2. from argparse import ArgumentParser
  3. import mmcv
  4. import requests
  5. import torch
  6. from mmengine.structures import InstanceData
  7. from mmdet.apis import inference_detector, init_detector
  8. from mmdet.registry import VISUALIZERS
  9. from mmdet.structures import DetDataSample
  10. def parse_args():
  11. parser = ArgumentParser()
  12. parser.add_argument('img', help='Image file')
  13. parser.add_argument('config', help='Config file')
  14. parser.add_argument('checkpoint', help='Checkpoint file')
  15. parser.add_argument('model_name', help='The model name in the server')
  16. parser.add_argument(
  17. '--inference-addr',
  18. default='127.0.0.1:8080',
  19. help='Address and port of the inference server')
  20. parser.add_argument(
  21. '--device', default='cuda:0', help='Device used for inference')
  22. parser.add_argument(
  23. '--score-thr', type=float, default=0.5, help='bbox score threshold')
  24. parser.add_argument(
  25. '--work-dir',
  26. type=str,
  27. default=None,
  28. help='output directory to save drawn results.')
  29. args = parser.parse_args()
  30. return args
  31. def align_ts_output(inputs, metainfo, device):
  32. bboxes = []
  33. labels = []
  34. scores = []
  35. for i, pred in enumerate(inputs):
  36. bboxes.append(pred['bbox'])
  37. labels.append(pred['class_label'])
  38. scores.append(pred['score'])
  39. pred_instances = InstanceData(metainfo=metainfo)
  40. pred_instances.bboxes = torch.tensor(
  41. bboxes, dtype=torch.float32, device=device)
  42. pred_instances.labels = torch.tensor(
  43. labels, dtype=torch.int64, device=device)
  44. pred_instances.scores = torch.tensor(
  45. scores, dtype=torch.float32, device=device)
  46. ts_data_sample = DetDataSample(pred_instances=pred_instances)
  47. return ts_data_sample
  48. def main(args):
  49. # build the model from a config file and a checkpoint file
  50. model = init_detector(args.config, args.checkpoint, device=args.device)
  51. # test a single image
  52. pytorch_results = inference_detector(model, args.img)
  53. keep = pytorch_results.pred_instances.scores >= args.score_thr
  54. pytorch_results.pred_instances = pytorch_results.pred_instances[keep]
  55. # init visualizer
  56. visualizer = VISUALIZERS.build(model.cfg.visualizer)
  57. # the dataset_meta is loaded from the checkpoint and
  58. # then pass to the model in init_detector
  59. visualizer.dataset_meta = model.dataset_meta
  60. # show the results
  61. img = mmcv.imread(args.img)
  62. img = mmcv.imconvert(img, 'bgr', 'rgb')
  63. pt_out_file = None
  64. ts_out_file = None
  65. if args.work_dir is not None:
  66. os.makedirs(args.work_dir, exist_ok=True)
  67. pt_out_file = os.path.join(args.work_dir, 'pytorch_result.png')
  68. ts_out_file = os.path.join(args.work_dir, 'torchserve_result.png')
  69. visualizer.add_datasample(
  70. 'pytorch_result',
  71. img.copy(),
  72. data_sample=pytorch_results,
  73. draw_gt=False,
  74. out_file=pt_out_file,
  75. show=True,
  76. wait_time=0)
  77. url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
  78. with open(args.img, 'rb') as image:
  79. response = requests.post(url, image)
  80. metainfo = pytorch_results.pred_instances.metainfo
  81. ts_results = align_ts_output(response.json(), metainfo, args.device)
  82. visualizer.add_datasample(
  83. 'torchserve_result',
  84. img,
  85. data_sample=ts_results,
  86. draw_gt=False,
  87. out_file=ts_out_file,
  88. show=True,
  89. wait_time=0)
  90. assert torch.allclose(pytorch_results.pred_instances.bboxes,
  91. ts_results.pred_instances.bboxes)
  92. assert torch.allclose(pytorch_results.pred_instances.labels,
  93. ts_results.pred_instances.labels)
  94. assert torch.allclose(pytorch_results.pred_instances.scores,
  95. ts_results.pred_instances.scores)
  96. if __name__ == '__main__':
  97. args = parse_args()
  98. main(args)