mmdet_handler.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import base64
  3. import os
  4. import mmcv
  5. import numpy as np
  6. import torch
  7. from ts.torch_handler.base_handler import BaseHandler
  8. from mmdet.apis import inference_detector, init_detector
  9. class MMdetHandler(BaseHandler):
  10. threshold = 0.5
  11. def initialize(self, context):
  12. properties = context.system_properties
  13. self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
  14. self.device = torch.device(self.map_location + ':' +
  15. str(properties.get('gpu_id')) if torch.cuda.
  16. is_available() else self.map_location)
  17. self.manifest = context.manifest
  18. model_dir = properties.get('model_dir')
  19. serialized_file = self.manifest['model']['serializedFile']
  20. checkpoint = os.path.join(model_dir, serialized_file)
  21. self.config_file = os.path.join(model_dir, 'config.py')
  22. self.model = init_detector(self.config_file, checkpoint, self.device)
  23. self.initialized = True
  24. def preprocess(self, data):
  25. images = []
  26. for row in data:
  27. image = row.get('data') or row.get('body')
  28. if isinstance(image, str):
  29. image = base64.b64decode(image)
  30. image = mmcv.imfrombytes(image)
  31. images.append(image)
  32. return images
  33. def inference(self, data, *args, **kwargs):
  34. results = inference_detector(self.model, data)
  35. return results
  36. def postprocess(self, data):
  37. # Format output following the example ObjectDetectionHandler format
  38. output = []
  39. for data_sample in data:
  40. pred_instances = data_sample.pred_instances
  41. bboxes = pred_instances.bboxes.cpu().numpy().astype(
  42. np.float32).tolist()
  43. labels = pred_instances.labels.cpu().numpy().astype(
  44. np.int32).tolist()
  45. scores = pred_instances.scores.cpu().numpy().astype(
  46. np.float32).tolist()
  47. preds = []
  48. for idx in range(len(labels)):
  49. cls_score, bbox, cls_label = scores[idx], bboxes[idx], labels[
  50. idx]
  51. if cls_score >= self.threshold:
  52. class_name = self.model.dataset_meta['classes'][cls_label]
  53. result = dict(
  54. class_label=cls_label,
  55. class_name=class_name,
  56. bbox=bbox,
  57. score=cls_score)
  58. preds.append(result)
  59. output.append(preds)
  60. return output