coco_metric.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import datetime
  3. import itertools
  4. import os.path as osp
  5. import tempfile
  6. from collections import OrderedDict
  7. from typing import Dict, List, Optional, Sequence, Union
  8. import numpy as np
  9. import torch
  10. from mmengine.evaluator import BaseMetric
  11. from mmengine.fileio import dump, get_local_path, load
  12. from mmengine.logging import MMLogger
  13. from terminaltables import AsciiTable
  14. from mmdet.datasets.api_wrappers import COCO, COCOeval
  15. from mmdet.registry import METRICS
  16. from mmdet.structures.mask import encode_mask_results
  17. from ..functional import eval_recalls
  18. @METRICS.register_module()
  19. class CocoMetric(BaseMetric):
  20. """COCO evaluation metric.
  21. Evaluate AR, AP, and mAP for detection tasks including proposal/box
  22. detection and instance segmentation. Please refer to
  23. https://cocodataset.org/#detection-eval for more details.
  24. Args:
  25. ann_file (str, optional): Path to the coco format annotation file.
  26. If not specified, ground truth annotations from the dataset will
  27. be converted to coco format. Defaults to None.
  28. metric (str | List[str]): Metrics to be evaluated. Valid metrics
  29. include 'bbox', 'segm', 'proposal', and 'proposal_fast'.
  30. Defaults to 'bbox'.
  31. classwise (bool): Whether to evaluate the metric class-wise.
  32. Defaults to False.
  33. proposal_nums (Sequence[int]): Numbers of proposals to be evaluated.
  34. Defaults to (100, 300, 1000).
  35. iou_thrs (float | List[float], optional): IoU threshold to compute AP
  36. and AR. If not specified, IoUs from 0.5 to 0.95 will be used.
  37. Defaults to None.
  38. metric_items (List[str], optional): Metric result names to be
  39. recorded in the evaluation result. Defaults to None.
  40. format_only (bool): Format the output results without perform
  41. evaluation. It is useful when you want to format the result
  42. to a specific format and submit it to the test server.
  43. Defaults to False.
  44. outfile_prefix (str, optional): The prefix of json files. It includes
  45. the file path and the prefix of filename, e.g., "a/b/prefix".
  46. If not specified, a temp file will be created. Defaults to None.
  47. file_client_args (dict, optional): Arguments to instantiate the
  48. corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
  49. backend_args (dict, optional): Arguments to instantiate the
  50. corresponding backend. Defaults to None.
  51. collect_device (str): Device name used for collecting results from
  52. different ranks during distributed training. Must be 'cpu' or
  53. 'gpu'. Defaults to 'cpu'.
  54. prefix (str, optional): The prefix that will be added in the metric
  55. names to disambiguate homonymous metrics of different evaluators.
  56. If prefix is not provided in the argument, self.default_prefix
  57. will be used instead. Defaults to None.
  58. sort_categories (bool): Whether sort categories in annotations. Only
  59. used for `Objects365V1Dataset`. Defaults to False.
  60. """
  61. default_prefix: Optional[str] = 'coco'
  62. def __init__(self,
  63. ann_file: Optional[str] = None,
  64. metric: Union[str, List[str]] = 'bbox',
  65. classwise: bool = False,
  66. proposal_nums: Sequence[int] = (100, 300, 1000),
  67. iou_thrs: Optional[Union[float, Sequence[float]]] = None,
  68. metric_items: Optional[Sequence[str]] = None,
  69. format_only: bool = False,
  70. outfile_prefix: Optional[str] = None,
  71. file_client_args: dict = None,
  72. backend_args: dict = None,
  73. collect_device: str = 'cpu',
  74. prefix: Optional[str] = None,
  75. sort_categories: bool = False) -> None:
  76. super().__init__(collect_device=collect_device, prefix=prefix)
  77. # coco evaluation metrics
  78. self.metrics = metric if isinstance(metric, list) else [metric]
  79. allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
  80. for metric in self.metrics:
  81. if metric not in allowed_metrics:
  82. raise KeyError(
  83. "metric should be one of 'bbox', 'segm', 'proposal', "
  84. f"'proposal_fast', but got {metric}.")
  85. # do class wise evaluation, default False
  86. self.classwise = classwise
  87. # proposal_nums used to compute recall or precision.
  88. self.proposal_nums = list(proposal_nums)
  89. # iou_thrs used to compute recall or precision.
  90. if iou_thrs is None:
  91. iou_thrs = np.linspace(
  92. .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
  93. self.iou_thrs = iou_thrs
  94. self.metric_items = metric_items
  95. self.format_only = format_only
  96. if self.format_only:
  97. assert outfile_prefix is not None, 'outfile_prefix must be not'
  98. 'None when format_only is True, otherwise the result files will'
  99. 'be saved to a temp directory which will be cleaned up at the end.'
  100. self.outfile_prefix = outfile_prefix
  101. self.backend_args = backend_args
  102. if file_client_args is not None:
  103. raise RuntimeError(
  104. 'The `file_client_args` is deprecated, '
  105. 'please use `backend_args` instead, please refer to'
  106. 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
  107. )
  108. # if ann_file is not specified,
  109. # initialize coco api with the converted dataset
  110. if ann_file is not None:
  111. with get_local_path(
  112. ann_file, backend_args=self.backend_args) as local_path:
  113. self._coco_api = COCO(local_path)
  114. if sort_categories:
  115. # 'categories' list in objects365_train.json and
  116. # objects365_val.json is inconsistent, need sort
  117. # list(or dict) before get cat_ids.
  118. cats = self._coco_api.cats
  119. sorted_cats = {i: cats[i] for i in sorted(cats)}
  120. self._coco_api.cats = sorted_cats
  121. categories = self._coco_api.dataset['categories']
  122. sorted_categories = sorted(
  123. categories, key=lambda i: i['id'])
  124. self._coco_api.dataset['categories'] = sorted_categories
  125. else:
  126. self._coco_api = None
  127. # handle dataset lazy init
  128. self.cat_ids = None
  129. self.img_ids = None
  130. def fast_eval_recall(self,
  131. results: List[dict],
  132. proposal_nums: Sequence[int],
  133. iou_thrs: Sequence[float],
  134. logger: Optional[MMLogger] = None) -> np.ndarray:
  135. """Evaluate proposal recall with COCO's fast_eval_recall.
  136. Args:
  137. results (List[dict]): Results of the dataset.
  138. proposal_nums (Sequence[int]): Proposal numbers used for
  139. evaluation.
  140. iou_thrs (Sequence[float]): IoU thresholds used for evaluation.
  141. logger (MMLogger, optional): Logger used for logging the recall
  142. summary.
  143. Returns:
  144. np.ndarray: Averaged recall results.
  145. """
  146. gt_bboxes = []
  147. pred_bboxes = [result['bboxes'] for result in results]
  148. for i in range(len(self.img_ids)):
  149. ann_ids = self._coco_api.get_ann_ids(img_ids=self.img_ids[i])
  150. ann_info = self._coco_api.load_anns(ann_ids)
  151. if len(ann_info) == 0:
  152. gt_bboxes.append(np.zeros((0, 4)))
  153. continue
  154. bboxes = []
  155. for ann in ann_info:
  156. if ann.get('ignore', False) or ann['iscrowd']:
  157. continue
  158. x1, y1, w, h = ann['bbox']
  159. bboxes.append([x1, y1, x1 + w, y1 + h])
  160. bboxes = np.array(bboxes, dtype=np.float32)
  161. if bboxes.shape[0] == 0:
  162. bboxes = np.zeros((0, 4))
  163. gt_bboxes.append(bboxes)
  164. recalls = eval_recalls(
  165. gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger)
  166. ar = recalls.mean(axis=1)
  167. return ar
  168. def xyxy2xywh(self, bbox: np.ndarray) -> list:
  169. """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
  170. evaluation.
  171. Args:
  172. bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
  173. ``xyxy`` order.
  174. Returns:
  175. list[float]: The converted bounding boxes, in ``xywh`` order.
  176. """
  177. _bbox: List = bbox.tolist()
  178. return [
  179. _bbox[0],
  180. _bbox[1],
  181. _bbox[2] - _bbox[0],
  182. _bbox[3] - _bbox[1],
  183. ]
  184. def results2json(self, results: Sequence[dict],
  185. outfile_prefix: str) -> dict:
  186. """Dump the detection results to a COCO style json file.
  187. There are 3 types of results: proposals, bbox predictions, mask
  188. predictions, and they have different data types. This method will
  189. automatically recognize the type, and dump them to json files.
  190. Args:
  191. results (Sequence[dict]): Testing results of the
  192. dataset.
  193. outfile_prefix (str): The filename prefix of the json files. If the
  194. prefix is "somepath/xxx", the json files will be named
  195. "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
  196. "somepath/xxx.proposal.json".
  197. Returns:
  198. dict: Possible keys are "bbox", "segm", "proposal", and
  199. values are corresponding filenames.
  200. """
  201. bbox_json_results = []
  202. segm_json_results = [] if 'masks' in results[0] else None
  203. for idx, result in enumerate(results):
  204. image_id = result.get('img_id', idx)
  205. labels = result['labels']
  206. bboxes = result['bboxes']
  207. scores = result['scores']
  208. # bbox results
  209. for i, label in enumerate(labels):
  210. data = dict()
  211. data['image_id'] = image_id
  212. data['bbox'] = self.xyxy2xywh(bboxes[i])
  213. data['score'] = float(scores[i])
  214. data['category_id'] = self.cat_ids[label]
  215. bbox_json_results.append(data)
  216. if segm_json_results is None:
  217. continue
  218. # segm results
  219. masks = result['masks']
  220. mask_scores = result.get('mask_scores', scores)
  221. for i, label in enumerate(labels):
  222. data = dict()
  223. data['image_id'] = image_id
  224. data['bbox'] = self.xyxy2xywh(bboxes[i])
  225. data['score'] = float(mask_scores[i])
  226. data['category_id'] = self.cat_ids[label]
  227. if isinstance(masks[i]['counts'], bytes):
  228. masks[i]['counts'] = masks[i]['counts'].decode()
  229. data['segmentation'] = masks[i]
  230. segm_json_results.append(data)
  231. result_files = dict()
  232. result_files['bbox'] = f'{outfile_prefix}.bbox.json'
  233. result_files['proposal'] = f'{outfile_prefix}.bbox.json'
  234. dump(bbox_json_results, result_files['bbox'])
  235. if segm_json_results is not None:
  236. result_files['segm'] = f'{outfile_prefix}.segm.json'
  237. dump(segm_json_results, result_files['segm'])
  238. return result_files
  239. def gt_to_coco_json(self, gt_dicts: Sequence[dict],
  240. outfile_prefix: str) -> str:
  241. """Convert ground truth to coco format json file.
  242. Args:
  243. gt_dicts (Sequence[dict]): Ground truth of the dataset.
  244. outfile_prefix (str): The filename prefix of the json files. If the
  245. prefix is "somepath/xxx", the json file will be named
  246. "somepath/xxx.gt.json".
  247. Returns:
  248. str: The filename of the json file.
  249. """
  250. categories = [
  251. dict(id=id, name=name)
  252. for id, name in enumerate(self.dataset_meta['classes'])
  253. ]
  254. image_infos = []
  255. annotations = []
  256. for idx, gt_dict in enumerate(gt_dicts):
  257. img_id = gt_dict.get('img_id', idx)
  258. image_info = dict(
  259. id=img_id,
  260. width=gt_dict['width'],
  261. height=gt_dict['height'],
  262. file_name='')
  263. image_infos.append(image_info)
  264. for ann in gt_dict['anns']:
  265. label = ann['bbox_label']
  266. bbox = ann['bbox']
  267. coco_bbox = [
  268. bbox[0],
  269. bbox[1],
  270. bbox[2] - bbox[0],
  271. bbox[3] - bbox[1],
  272. ]
  273. annotation = dict(
  274. id=len(annotations) +
  275. 1, # coco api requires id starts with 1
  276. image_id=img_id,
  277. bbox=coco_bbox,
  278. iscrowd=ann.get('ignore_flag', 0),
  279. category_id=int(label),
  280. area=coco_bbox[2] * coco_bbox[3])
  281. if ann.get('mask', None):
  282. mask = ann['mask']
  283. # area = mask_util.area(mask)
  284. if isinstance(mask, dict) and isinstance(
  285. mask['counts'], bytes):
  286. mask['counts'] = mask['counts'].decode()
  287. annotation['segmentation'] = mask
  288. # annotation['area'] = float(area)
  289. annotations.append(annotation)
  290. info = dict(
  291. date_created=str(datetime.datetime.now()),
  292. description='Coco json file converted by mmdet CocoMetric.')
  293. coco_json = dict(
  294. info=info,
  295. images=image_infos,
  296. categories=categories,
  297. licenses=None,
  298. )
  299. if len(annotations) > 0:
  300. coco_json['annotations'] = annotations
  301. converted_json_path = f'{outfile_prefix}.gt.json'
  302. dump(coco_json, converted_json_path)
  303. return converted_json_path
  304. # TODO: data_batch is no longer needed, consider adjusting the
  305. # parameter position
  306. def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
  307. """Process one batch of data samples and predictions. The processed
  308. results should be stored in ``self.results``, which will be used to
  309. compute the metrics when all batches have been processed.
  310. Args:
  311. data_batch (dict): A batch of data from the dataloader.
  312. data_samples (Sequence[dict]): A batch of data samples that
  313. contain annotations and predictions.
  314. """
  315. for data_sample in data_samples:
  316. result = dict()
  317. pred = data_sample['pred_instances']
  318. result['img_id'] = data_sample['img_id']
  319. result['bboxes'] = pred['bboxes'].cpu().numpy()
  320. result['scores'] = pred['scores'].cpu().numpy()
  321. result['labels'] = pred['labels'].cpu().numpy()
  322. # encode mask to RLE
  323. if 'masks' in pred:
  324. result['masks'] = encode_mask_results(
  325. pred['masks'].detach().cpu().numpy()) if isinstance(
  326. pred['masks'], torch.Tensor) else pred['masks']
  327. # some detectors use different scores for bbox and mask
  328. if 'mask_scores' in pred:
  329. result['mask_scores'] = pred['mask_scores'].cpu().numpy()
  330. # parse gt
  331. gt = dict()
  332. gt['width'] = data_sample['ori_shape'][1]
  333. gt['height'] = data_sample['ori_shape'][0]
  334. gt['img_id'] = data_sample['img_id']
  335. if self._coco_api is None:
  336. # TODO: Need to refactor to support LoadAnnotations
  337. assert 'instances' in data_sample, \
  338. 'ground truth is required for evaluation when ' \
  339. '`ann_file` is not provided'
  340. gt['anns'] = data_sample['instances']
  341. # add converted result to the results list
  342. self.results.append((gt, result))
  343. def compute_metrics(self, results: list) -> Dict[str, float]:
  344. """Compute the metrics from processed results.
  345. Args:
  346. results (list): The processed results of each batch.
  347. Returns:
  348. Dict[str, float]: The computed metrics. The keys are the names of
  349. the metrics, and the values are corresponding results.
  350. """
  351. logger: MMLogger = MMLogger.get_current_instance()
  352. # split gt and prediction list
  353. gts, preds = zip(*results)
  354. tmp_dir = None
  355. if self.outfile_prefix is None:
  356. tmp_dir = tempfile.TemporaryDirectory()
  357. outfile_prefix = osp.join(tmp_dir.name, 'results')
  358. else:
  359. outfile_prefix = self.outfile_prefix
  360. if self._coco_api is None:
  361. # use converted gt json file to initialize coco api
  362. logger.info('Converting ground truth to coco format...')
  363. coco_json_path = self.gt_to_coco_json(
  364. gt_dicts=gts, outfile_prefix=outfile_prefix)
  365. self._coco_api = COCO(coco_json_path)
  366. # handle lazy init
  367. if self.cat_ids is None:
  368. self.cat_ids = self._coco_api.get_cat_ids(
  369. cat_names=self.dataset_meta['classes'])
  370. if self.img_ids is None:
  371. self.img_ids = self._coco_api.get_img_ids()
  372. # convert predictions to coco format and dump to json file
  373. result_files = self.results2json(preds, outfile_prefix)
  374. eval_results = OrderedDict()
  375. if self.format_only:
  376. logger.info('results are saved in '
  377. f'{osp.dirname(outfile_prefix)}')
  378. return eval_results
  379. for metric in self.metrics:
  380. logger.info(f'Evaluating {metric}...')
  381. # TODO: May refactor fast_eval_recall to an independent metric?
  382. # fast eval recall
  383. if metric == 'proposal_fast':
  384. ar = self.fast_eval_recall(
  385. preds, self.proposal_nums, self.iou_thrs, logger=logger)
  386. log_msg = []
  387. for i, num in enumerate(self.proposal_nums):
  388. eval_results[f'AR@{num}'] = ar[i]
  389. log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
  390. log_msg = ''.join(log_msg)
  391. logger.info(log_msg)
  392. continue
  393. # evaluate proposal, bbox and segm
  394. iou_type = 'bbox' if metric == 'proposal' else metric
  395. if metric not in result_files:
  396. raise KeyError(f'{metric} is not in results')
  397. try:
  398. predictions = load(result_files[metric])
  399. if iou_type == 'segm':
  400. # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
  401. # When evaluating mask AP, if the results contain bbox,
  402. # cocoapi will use the box area instead of the mask area
  403. # for calculating the instance area. Though the overall AP
  404. # is not affected, this leads to different
  405. # small/medium/large mask AP results.
  406. for x in predictions:
  407. x.pop('bbox')
  408. coco_dt = self._coco_api.loadRes(predictions)
  409. except IndexError:
  410. logger.error(
  411. 'The testing results of the whole dataset is empty.')
  412. break
  413. coco_eval = COCOeval(self._coco_api, coco_dt, iou_type)
  414. coco_eval.params.catIds = self.cat_ids
  415. coco_eval.params.imgIds = self.img_ids
  416. coco_eval.params.maxDets = list(self.proposal_nums)
  417. coco_eval.params.iouThrs = self.iou_thrs
  418. # mapping of cocoEval.stats
  419. coco_metric_names = {
  420. 'mAP': 0,
  421. 'mAP_50': 1,
  422. 'mAP_75': 2,
  423. 'mAP_s': 3,
  424. 'mAP_m': 4,
  425. 'mAP_l': 5,
  426. 'AR@100': 6,
  427. 'AR@300': 7,
  428. 'AR@1000': 8,
  429. 'AR_s@1000': 9,
  430. 'AR_m@1000': 10,
  431. 'AR_l@1000': 11
  432. }
  433. metric_items = self.metric_items
  434. if metric_items is not None:
  435. for metric_item in metric_items:
  436. if metric_item not in coco_metric_names:
  437. raise KeyError(
  438. f'metric item "{metric_item}" is not supported')
  439. if metric == 'proposal':
  440. coco_eval.params.useCats = 0
  441. coco_eval.evaluate()
  442. coco_eval.accumulate()
  443. coco_eval.summarize()
  444. if metric_items is None:
  445. metric_items = [
  446. 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
  447. 'AR_m@1000', 'AR_l@1000'
  448. ]
  449. for item in metric_items:
  450. val = float(
  451. f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
  452. eval_results[item] = val
  453. else:
  454. coco_eval.evaluate()
  455. coco_eval.accumulate()
  456. coco_eval.summarize()
  457. if self.classwise: # Compute per-category AP
  458. # Compute per-category AP
  459. # from https://github.com/facebookresearch/detectron2/
  460. precisions = coco_eval.eval['precision']
  461. # precision: (iou, recall, cls, area range, max dets)
  462. assert len(self.cat_ids) == precisions.shape[2]
  463. results_per_category = []
  464. for idx, cat_id in enumerate(self.cat_ids):
  465. t = []
  466. # area range index 0: all area ranges
  467. # max dets index -1: typically 100 per image
  468. nm = self._coco_api.loadCats(cat_id)[0]
  469. precision = precisions[:, :, idx, 0, -1]
  470. precision = precision[precision > -1]
  471. if precision.size:
  472. ap = np.mean(precision)
  473. else:
  474. ap = float('nan')
  475. t.append(f'{nm["name"]}')
  476. t.append(f'{round(ap, 3)}')
  477. eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
  478. # indexes of IoU @50 and @75
  479. for iou in [0, 5]:
  480. precision = precisions[iou, :, idx, 0, -1]
  481. precision = precision[precision > -1]
  482. if precision.size:
  483. ap = np.mean(precision)
  484. else:
  485. ap = float('nan')
  486. t.append(f'{round(ap, 3)}')
  487. # indexes of area of small, median and large
  488. for area in [1, 2, 3]:
  489. precision = precisions[:, :, idx, area, -1]
  490. precision = precision[precision > -1]
  491. if precision.size:
  492. ap = np.mean(precision)
  493. else:
  494. ap = float('nan')
  495. t.append(f'{round(ap, 3)}')
  496. results_per_category.append(tuple(t))
  497. num_columns = len(results_per_category[0])
  498. results_flatten = list(
  499. itertools.chain(*results_per_category))
  500. headers = [
  501. 'category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s',
  502. 'mAP_m', 'mAP_l'
  503. ]
  504. results_2d = itertools.zip_longest(*[
  505. results_flatten[i::num_columns]
  506. for i in range(num_columns)
  507. ])
  508. table_data = [headers]
  509. table_data += [result for result in results_2d]
  510. table = AsciiTable(table_data)
  511. logger.info('\n' + table.table)
  512. if metric_items is None:
  513. metric_items = [
  514. 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
  515. ]
  516. for metric_item in metric_items:
  517. key = f'{metric}_{metric_item}'
  518. val = coco_eval.stats[coco_metric_names[metric_item]]
  519. eval_results[key] = float(f'{round(val, 3)}')
  520. ap = coco_eval.stats[:6]
  521. logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
  522. f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
  523. f'{ap[4]:.3f} {ap[5]:.3f}')
  524. if tmp_dir is not None:
  525. tmp_dir.cleanup()
  526. return eval_results