mmdetection.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import io
  3. import json
  4. import logging
  5. import os
  6. from urllib.parse import urlparse
  7. import boto3
  8. from botocore.exceptions import ClientError
  9. from label_studio_ml.model import LabelStudioMLBase
  10. from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size,
  11. get_single_tag_keys)
  12. from label_studio_tools.core.utils.io import get_data_dir
  13. from mmdet.apis import inference_detector, init_detector
  14. logger = logging.getLogger(__name__)
  15. class MMDetection(LabelStudioMLBase):
  16. """Object detector based on https://github.com/open-mmlab/mmdetection."""
  17. def __init__(self,
  18. config_file=None,
  19. checkpoint_file=None,
  20. image_dir=None,
  21. labels_file=None,
  22. score_threshold=0.5,
  23. device='cpu',
  24. **kwargs):
  25. super(MMDetection, self).__init__(**kwargs)
  26. config_file = config_file or os.environ['config_file']
  27. checkpoint_file = checkpoint_file or os.environ['checkpoint_file']
  28. self.config_file = config_file
  29. self.checkpoint_file = checkpoint_file
  30. self.labels_file = labels_file
  31. # default Label Studio image upload folder
  32. upload_dir = os.path.join(get_data_dir(), 'media', 'upload')
  33. self.image_dir = image_dir or upload_dir
  34. logger.debug(
  35. f'{self.__class__.__name__} reads images from {self.image_dir}')
  36. if self.labels_file and os.path.exists(self.labels_file):
  37. self.label_map = json_load(self.labels_file)
  38. else:
  39. self.label_map = {}
  40. self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys( # noqa E501
  41. self.parsed_label_config, 'RectangleLabels', 'Image')
  42. schema = list(self.parsed_label_config.values())[0]
  43. self.labels_in_config = set(self.labels_in_config)
  44. # Collect label maps from `predicted_values="airplane,car"` attribute in <Label> tag # noqa E501
  45. self.labels_attrs = schema.get('labels_attrs')
  46. if self.labels_attrs:
  47. for label_name, label_attrs in self.labels_attrs.items():
  48. for predicted_value in label_attrs.get('predicted_values',
  49. '').split(','):
  50. self.label_map[predicted_value] = label_name
  51. print('Load new model from: ', config_file, checkpoint_file)
  52. self.model = init_detector(config_file, checkpoint_file, device=device)
  53. self.score_thresh = score_threshold
  54. def _get_image_url(self, task):
  55. image_url = task['data'].get(
  56. self.value) or task['data'].get(DATA_UNDEFINED_NAME)
  57. if image_url.startswith('s3://'):
  58. # presign s3 url
  59. r = urlparse(image_url, allow_fragments=False)
  60. bucket_name = r.netloc
  61. key = r.path.lstrip('/')
  62. client = boto3.client('s3')
  63. try:
  64. image_url = client.generate_presigned_url(
  65. ClientMethod='get_object',
  66. Params={
  67. 'Bucket': bucket_name,
  68. 'Key': key
  69. })
  70. except ClientError as exc:
  71. logger.warning(
  72. f'Can\'t generate presigned URL for {image_url}. Reason: {exc}' # noqa E501
  73. )
  74. return image_url
  75. def predict(self, tasks, **kwargs):
  76. assert len(tasks) == 1
  77. task = tasks[0]
  78. image_url = self._get_image_url(task)
  79. image_path = self.get_local_path(image_url)
  80. model_results = inference_detector(self.model,
  81. image_path).pred_instances
  82. results = []
  83. all_scores = []
  84. img_width, img_height = get_image_size(image_path)
  85. print(f'>>> model_results: {model_results}')
  86. print(f'>>> label_map {self.label_map}')
  87. print(f'>>> self.model.dataset_meta: {self.model.dataset_meta}')
  88. classes = self.model.dataset_meta.get('classes')
  89. print(f'Classes >>> {classes}')
  90. for item in model_results:
  91. print(f'item >>>>> {item}')
  92. bboxes, label, scores = item['bboxes'], item['labels'], item[
  93. 'scores']
  94. score = float(scores[-1])
  95. if score < self.score_thresh:
  96. continue
  97. print(f'bboxes >>>>> {bboxes}')
  98. print(f'label >>>>> {label}')
  99. output_label = classes[list(self.label_map.get(label, label))[0]]
  100. print(f'>>> output_label: {output_label}')
  101. if output_label not in self.labels_in_config:
  102. print(output_label + ' label not found in project config.')
  103. continue
  104. for bbox in bboxes:
  105. bbox = list(bbox)
  106. if not bbox:
  107. continue
  108. x, y, xmax, ymax = bbox[:4]
  109. results.append({
  110. 'from_name': self.from_name,
  111. 'to_name': self.to_name,
  112. 'type': 'rectanglelabels',
  113. 'value': {
  114. 'rectanglelabels': [output_label],
  115. 'x': float(x) / img_width * 100,
  116. 'y': float(y) / img_height * 100,
  117. 'width': (float(xmax) - float(x)) / img_width * 100,
  118. 'height': (float(ymax) - float(y)) / img_height * 100
  119. },
  120. 'score': score
  121. })
  122. all_scores.append(score)
  123. avg_score = sum(all_scores) / max(len(all_scores), 1)
  124. print(f'>>> RESULTS: {results}')
  125. return [{'result': results, 'score': avg_score}]
  126. def json_load(file, int_keys=False):
  127. with io.open(file, encoding='utf8') as f:
  128. data = json.load(f)
  129. if int_keys:
  130. return {int(k): v for k, v in data.items()}
  131. else:
  132. return data