12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from multiprocessing import Pool
- import numpy as np
- from mmengine.logging import print_log
- from mmengine.utils import is_str
- from terminaltables import AsciiTable
- from .bbox_overlaps import bbox_overlaps
- from .class_names import get_classes
def average_precision(recalls, precisions, mode='area'):
    """Compute average precision from recall/precision curves.

    Args:
        recalls (ndarray): Recall values, of shape (num_scales, num_dets)
            or (num_dets, ).
        precisions (ndarray): Precision values, same shape as ``recalls``.
        mode (str): 'area' integrates the area under the (monotonically
            non-increasing) precision-recall envelope; '11points' averages
            the best precision reachable at the recall levels
            [0, 0.1, ..., 1].

    Returns:
        float or ndarray: One AP value per scale. A scalar is returned when
        the inputs were 1-D.

    Raises:
        ValueError: If ``mode`` is neither 'area' nor '11points'.
    """
    single_curve = recalls.ndim == 1
    if single_curve:
        # promote a single curve to a (1, num_dets) batch
        recalls = recalls[np.newaxis, :]
        precisions = precisions[np.newaxis, :]
    assert recalls.shape == precisions.shape and recalls.ndim == 2
    num_scales = recalls.shape[0]
    ap = np.zeros(num_scales, dtype=np.float32)

    if mode == 'area':
        pad_zero = np.zeros((num_scales, 1), dtype=recalls.dtype)
        pad_one = np.ones((num_scales, 1), dtype=recalls.dtype)
        # pad the curves so that they span recall 0 .. 1
        mrec = np.hstack((pad_zero, recalls, pad_one))
        mpre = np.hstack((pad_zero, precisions, pad_zero))
        # precision envelope: running maximum taken from right to left
        mpre = np.maximum.accumulate(mpre[:, ::-1], axis=1)[:, ::-1]
        for scale_idx in range(num_scales):
            rec_row = mrec[scale_idx]
            # positions where the recall value actually changes
            steps = np.flatnonzero(rec_row[1:] != rec_row[:-1])
            widths = rec_row[steps + 1] - rec_row[steps]
            ap[scale_idx] = np.sum(widths * mpre[scale_idx, steps + 1])
    elif mode == '11points':
        recall_levels = np.arange(0, 1 + 1e-3, 0.1)
        for scale_idx in range(num_scales):
            for level in recall_levels:
                reachable = precisions[scale_idx,
                                       recalls[scale_idx, :] >= level]
                ap[scale_idx] += reachable.max() if reachable.size > 0 else 0
        ap /= 11
    else:
        raise ValueError(
            'Unrecognized mode, only "area" and "11points" are supported')

    return ap[0] if single_curve else ap
def tpfp_imagenet(det_bboxes,
                  gt_bboxes,
                  gt_bboxes_ignore=None,
                  default_iou_thr=0.5,
                  area_ranges=None,
                  use_legacy_coordinate=False,
                  **kwargs):
    """Check if detected bboxes are true positive or false positive.

    Each gt box gets its own IoU threshold,
    ``min(w * h / ((w + 10) * (h + 10)), default_iou_thr)``, and a detection
    may fall back to another still-unmatched gt when its best overlapping
    gt has already been taken.

    Args:
        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
        gt_bboxes_ignore (ndarray | None): Ignored gt bboxes of this image,
            of shape (k, 4). ``None`` means there are no ignored boxes.
            Defaults to None.
        default_iou_thr (float): IoU threshold to be considered as matched for
            medium and large bboxes (small ones have special rules).
            Defaults to 0.5.
        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
            in the format [(min1, max1), (min2, max2), ...]. Defaults to None.
        use_legacy_coordinate (bool): Whether to use coordinate system in
            mmdet v1.x. which means width, height should be
            calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
            Defaults to False.
        **kwargs: Accepted and ignored.

    Returns:
        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
        each array is (num_scales, m).
    """
    extra_length = 1. if use_legacy_coordinate else 0.

    # ``None`` is the advertised default, so it must behave like an empty
    # set of ignored boxes instead of failing on ``.shape`` below.
    if gt_bboxes_ignore is None:
        gt_bboxes_ignore = np.empty((0, gt_bboxes.shape[1]),
                                    dtype=gt_bboxes.dtype)

    # an indicator of ignored gts
    gt_ignore_inds = np.concatenate(
        (np.zeros(gt_bboxes.shape[0], dtype=bool),
         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
    # stack gt_bboxes and gt_bboxes_ignore for convenience
    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))

    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    if area_ranges is None:
        area_ranges = [(None, None)]
    num_scales = len(area_ranges)
    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp
    # of a certain scale.
    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
    fp = np.zeros((num_scales, num_dets), dtype=np.float32)

    # without any gt box, every det bbox within the area range is a false
    # positive
    if num_gts == 0:
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            det_areas = (
                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
            for i, (min_area, max_area) in enumerate(area_ranges):
                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
        return tp, fp

    # NOTE(review): the gt boxes are shifted by -1 before the overlap
    # computation; kept as in the original code, the reason is not visible
    # here - confirm it is intended.
    ious = bbox_overlaps(
        det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate)
    gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length
    gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length
    gt_areas = gt_w * gt_h
    # per-gt threshold: never above default_iou_thr, lower for small boxes
    iou_thrs = np.minimum(gt_areas / ((gt_w + 10.0) * (gt_h + 10.0)),
                          default_iou_thr)
    # sort all detections by scores in descending order
    sort_inds = np.argsort(-det_bboxes[:, -1])

    for k, (min_area, max_area) in enumerate(area_ranges):
        gt_covered = np.zeros(num_gts, dtype=bool)
        # if no area range is specified, gt_area_ignore is all False
        if min_area is None:
            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
        else:
            gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
        for i in sort_inds:
            max_iou = -1
            matched_gt = -1
            # find best overlapped available gt
            for j in range(num_gts):
                # different from PASCAL VOC: allow finding other gts if the
                # best overlapped ones are already matched by other det bboxes
                if gt_covered[j]:
                    continue
                if ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
                    max_iou = ious[i, j]
                    matched_gt = j
            # there are 4 cases for a det bbox:
            # 1. it matches a gt, tp = 1, fp = 0
            # 2. it matches an ignored gt, tp = 0, fp = 0
            # 3. it matches no gt and within area range, tp = 0, fp = 1
            # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
            if matched_gt >= 0:
                gt_covered[matched_gt] = 1
                if not (gt_ignore_inds[matched_gt]
                        or gt_area_ignore[matched_gt]):
                    tp[k, i] = 1
            elif min_area is None:
                fp[k, i] = 1
            else:
                bbox = det_bboxes[i, :4]
                area = (bbox[2] - bbox[0] + extra_length) * (
                    bbox[3] - bbox[1] + extra_length)
                if area >= min_area and area < max_area:
                    fp[k, i] = 1
    return tp, fp
def tpfp_default(det_bboxes,
                 gt_bboxes,
                 gt_bboxes_ignore=None,
                 iou_thr=0.5,
                 area_ranges=None,
                 use_legacy_coordinate=False,
                 **kwargs):
    """Check if detected bboxes are true positive or false positive.

    Every detection is compared against the gt box it overlaps most. In
    descending score order, the first detection that reaches ``iou_thr`` on
    a gt claims it as a true positive; later detections on the same gt are
    false positives.

    Args:
        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
        gt_bboxes_ignore (ndarray | None): Ignored gt bboxes of this image,
            of shape (k, 4). ``None`` means there are no ignored boxes.
            Defaults to None.
        iou_thr (float): IoU threshold to be considered as matched.
            Defaults to 0.5.
        area_ranges (list[tuple] | None): Range of bbox areas to be
            evaluated, in the format [(min1, max1), (min2, max2), ...].
            Defaults to None.
        use_legacy_coordinate (bool): Whether to use coordinate system in
            mmdet v1.x. which means width, height should be
            calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
            Defaults to False.
        **kwargs: Accepted and ignored.

    Returns:
        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
        each array is (num_scales, m).
    """
    extra_length = 1. if use_legacy_coordinate else 0.

    # ``None`` is the advertised default, so it must behave like an empty
    # set of ignored boxes instead of failing on ``.shape`` below.
    if gt_bboxes_ignore is None:
        gt_bboxes_ignore = np.empty((0, gt_bboxes.shape[1]),
                                    dtype=gt_bboxes.dtype)

    # an indicator of ignored gts
    gt_ignore_inds = np.concatenate(
        (np.zeros(gt_bboxes.shape[0], dtype=bool),
         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
    # stack gt_bboxes and gt_bboxes_ignore for convenience
    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))

    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    if area_ranges is None:
        area_ranges = [(None, None)]
    num_scales = len(area_ranges)
    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
    # a certain scale
    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
    fp = np.zeros((num_scales, num_dets), dtype=np.float32)

    # if there is no gt bboxes in this image, then all det bboxes
    # within area range are false positives
    if num_gts == 0:
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            det_areas = (
                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
            for i, (min_area, max_area) in enumerate(area_ranges):
                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
        return tp, fp

    ious = bbox_overlaps(
        det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
    # for each det, the max iou with all gts
    ious_max = ious.max(axis=1)
    # for each det, which gt overlaps most with it
    ious_argmax = ious.argmax(axis=1)
    # sort all dets in descending order by scores
    sort_inds = np.argsort(-det_bboxes[:, -1])

    for k, (min_area, max_area) in enumerate(area_ranges):
        gt_covered = np.zeros(num_gts, dtype=bool)
        # if no area range is specified, gt_area_ignore is all False
        if min_area is None:
            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
        else:
            gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
                gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
            gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
        for i in sort_inds:
            if ious_max[i] >= iou_thr:
                matched_gt = ious_argmax[i]
                if not (gt_ignore_inds[matched_gt]
                        or gt_area_ignore[matched_gt]):
                    if not gt_covered[matched_gt]:
                        gt_covered[matched_gt] = True
                        tp[k, i] = 1
                    else:
                        # the gt was already claimed by a higher-scored det
                        fp[k, i] = 1
                # otherwise ignore this detected bbox, tp = 0, fp = 0
            elif min_area is None:
                fp[k, i] = 1
            else:
                bbox = det_bboxes[i, :4]
                area = (bbox[2] - bbox[0] + extra_length) * (
                    bbox[3] - bbox[1] + extra_length)
                if area >= min_area and area < max_area:
                    fp[k, i] = 1
    return tp, fp
def tpfp_openimages(det_bboxes,
                    gt_bboxes,
                    gt_bboxes_ignore=None,
                    iou_thr=0.5,
                    area_ranges=None,
                    use_legacy_coordinate=False,
                    gt_bboxes_group_of=None,
                    use_group_of=True,
                    ioa_thr=0.5,
                    **kwargs):
    """Check if detected bboxes are true positive or false positive.

    Works like the default matching on the non-group-of gt boxes. When
    group-of flags are supplied and ``use_group_of`` is set, detections that
    did not become true positives are additionally matched against the
    group-of boxes by IoA.

    Args:
        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
        gt_bboxes_ignore (ndarray | None): Ignored gt bboxes of this image,
            of shape (k, 4). ``None`` means there are no ignored boxes.
            Defaults to None.
        iou_thr (float): IoU threshold to be considered as matched.
            Defaults to 0.5.
        area_ranges (list[tuple] | None): Range of bbox areas to be
            evaluated, in the format [(min1, max1), (min2, max2), ...].
            Defaults to None.
        use_legacy_coordinate (bool): Whether to use coordinate system in
            mmdet v1.x. which means width, height should be
            calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
            Defaults to False.
        gt_bboxes_group_of (ndarray): GT group_of of this image, used as a
            boolean mask over the stacked gt boxes. Defaults to None.
        use_group_of (bool): Whether to use group of when calculate TP and FP,
            which only used in OpenImages evaluation. Defaults to True.
        ioa_thr (float | None): IoA threshold to be considered as matched,
            which only used in OpenImages evaluation. ``None`` selects the
            default of 0.5. Defaults to 0.5.
        **kwargs: Accepted and ignored.

    Returns:
        tuple[np.ndarray]: Returns a tuple (tp, fp, det_bboxes), where
        (tp, fp) have elements 0 and 1. Without group-of handling tp and fp
        have shape (num_scales, m) and det_bboxes is returned unchanged.
        With group-of handling, detections matched to a group-of box are
        dropped and one entry per group-of box is appended to tp, fp and
        det_bboxes.
    """
    extra_length = 1. if use_legacy_coordinate else 0.

    # ``None`` is documented as valid but cannot be compared with an IoA
    # value; fall back to the signature default.
    if ioa_thr is None:
        ioa_thr = 0.5

    # ``None`` is the advertised default, so it must behave like an empty
    # set of ignored boxes instead of failing on ``.shape`` below.
    if gt_bboxes_ignore is None:
        gt_bboxes_ignore = np.empty((0, gt_bboxes.shape[1]),
                                    dtype=gt_bboxes.dtype)

    # an indicator of ignored gts
    gt_ignore_inds = np.concatenate(
        (np.zeros(gt_bboxes.shape[0], dtype=bool),
         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
    # stack gt_bboxes and gt_bboxes_ignore for convenience
    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))

    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    if area_ranges is None:
        area_ranges = [(None, None)]
    num_scales = len(area_ranges)
    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
    # a certain scale
    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
    fp = np.zeros((num_scales, num_dets), dtype=np.float32)

    # if there is no gt bboxes in this image, then all det bboxes
    # within area range are false positives
    if num_gts == 0:
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            det_areas = (
                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
            for i, (min_area, max_area) in enumerate(area_ranges):
                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
        return tp, fp, det_bboxes

    if gt_bboxes_group_of is not None and use_group_of:
        # if handle group-of boxes, divided gt boxes into two parts:
        # non-group-of and group-of. Then calculate ious and ioas through
        # non-group-of group-of gts respectively. This only used in
        # OpenImages evaluation.
        assert gt_bboxes_group_of.shape[0] == gt_bboxes.shape[0]
        non_group_gt_bboxes = gt_bboxes[~gt_bboxes_group_of]
        group_gt_bboxes = gt_bboxes[gt_bboxes_group_of]
        num_gts_group = group_gt_bboxes.shape[0]
        # NOTE(review): use_legacy_coordinate is not forwarded in this
        # branch, unlike the non-group-of branch below - confirm intended.
        ious = bbox_overlaps(det_bboxes, non_group_gt_bboxes)
        ioas = bbox_overlaps(det_bboxes, group_gt_bboxes, mode='iof')
    else:
        # if not consider group-of boxes, only calculate ious through gt boxes
        ious = bbox_overlaps(
            det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
        ioas = None

    if ious.shape[1] > 0:
        # for each det, the max iou with all gts
        ious_max = ious.max(axis=1)
        # for each det, which gt overlaps most with it
        ious_argmax = ious.argmax(axis=1)
        # sort all dets in descending order by scores
        sort_inds = np.argsort(-det_bboxes[:, -1])
        for k, (min_area, max_area) in enumerate(area_ranges):
            gt_covered = np.zeros(num_gts, dtype=bool)
            # if no area range is specified, gt_area_ignore is all False
            if min_area is None:
                gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
            else:
                gt_areas = (
                    gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
                        gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
                gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
            for i in sort_inds:
                if ious_max[i] >= iou_thr:
                    matched_gt = ious_argmax[i]
                    if not (gt_ignore_inds[matched_gt]
                            or gt_area_ignore[matched_gt]):
                        if not gt_covered[matched_gt]:
                            gt_covered[matched_gt] = True
                            tp[k, i] = 1
                        else:
                            fp[k, i] = 1
                    # otherwise ignore this detected bbox, tp = 0, fp = 0
                elif min_area is None:
                    fp[k, i] = 1
                else:
                    bbox = det_bboxes[i, :4]
                    area = (bbox[2] - bbox[0] + extra_length) * (
                        bbox[3] - bbox[1] + extra_length)
                    if area >= min_area and area < max_area:
                        fp[k, i] = 1
    else:
        # if there is no no-group-of gt bboxes in this image,
        # then all det bboxes within area range are false positives.
        # Only used in OpenImages evaluation.
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            det_areas = (
                det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
                    det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
            for i, (min_area, max_area) in enumerate(area_ranges):
                fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1

    # nothing more to do when no group-of boxes are involved
    if ioas is None or ioas.shape[1] <= 0:
        return tp, fp, det_bboxes

    # The evaluation of group-of TP and FP are done in two stages:
    # 1. All detections are first matched to non group-of boxes; true
    #    positives are determined.
    # 2. Detections that are determined as false positives are matched
    #    against group-of boxes and calculated group-of TP and FP.
    # Only used in OpenImages evaluation.
    det_bboxes_group = np.zeros(
        (num_scales, ioas.shape[1], det_bboxes.shape[1]), dtype=float)
    match_group_of = np.zeros((num_scales, num_dets), dtype=bool)
    tp_group = np.zeros((num_scales, num_gts_group), dtype=np.float32)
    ioas_max = ioas.max(axis=1)
    # for each det, which group-of gt overlaps most with it
    ioas_argmax = ioas.argmax(axis=1)
    # sort all dets in descending order by scores
    sort_inds = np.argsort(-det_bboxes[:, -1])
    for k, (min_area, max_area) in enumerate(area_ranges):
        box_is_covered = tp[k]
        # if no area range is specified, gt_area_ignore is all False
        if min_area is None:
            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
        else:
            # NOTE(review): extra_length is not added here, unlike the
            # first stage - confirm intended.
            gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
                gt_bboxes[:, 3] - gt_bboxes[:, 1])
            gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
        for i in sort_inds:
            matched_gt = ioas_argmax[i]
            # only dets that are not already true positives are considered
            if box_is_covered[i] or ioas_max[i] < ioa_thr:
                continue
            if gt_ignore_inds[matched_gt] or gt_area_ignore[matched_gt]:
                continue
            # the group-of box counts as detected (once) and the det is
            # absorbed into it
            tp_group[k, matched_gt] = 1
            match_group_of[k, i] = True
            # keep the highest-scored det as the group's representative
            if det_bboxes_group[k, matched_gt, -1] < det_bboxes[i, -1]:
                det_bboxes_group[k, matched_gt] = det_bboxes[i]

    fp_group = (tp_group <= 0).astype(float)
    tps = []
    fps = []
    # concatenate tp, fp, and det-boxes which not matched group of
    # gt boxes and tp_group, fp_group, and det_bboxes_group which
    # matched group of boxes respectively.
    # NOTE(review): det_bboxes is overwritten on every iteration, so this
    # looks like it only supports a single scale - confirm with callers.
    for i in range(num_scales):
        tps.append(
            np.concatenate((tp[i][~match_group_of[i]], tp_group[i])))
        fps.append(
            np.concatenate((fp[i][~match_group_of[i]], fp_group[i])))
        det_bboxes = np.concatenate(
            (det_bboxes[~match_group_of[i]], det_bboxes_group[i]))
    tp = np.vstack(tps)
    fp = np.vstack(fps)
    return tp, fp, det_bboxes
def get_cls_results(det_results, annotations, class_id):
    """Collect detections and ground truths that belong to one class.

    Args:
        det_results (list[list]): Same as `eval_map()`.
        annotations (list[dict]): Same as `eval_map()`.
        class_id (int): ID of a specific class.

    Returns:
        tuple[list[np.ndarray]]: Per-image detected bboxes, gt bboxes and
        ignored gt bboxes of the class.
    """
    cls_dets = [per_image[class_id] for per_image in det_results]

    cls_gts, cls_gts_ignore = [], []
    for ann in annotations:
        cls_gts.append(ann['bboxes'][ann['labels'] == class_id, :])

        if ann.get('labels_ignore', None) is None:
            # no ignore annotations for this image: use an empty placeholder
            cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
            continue
        ignore_mask = ann['labels_ignore'] == class_id
        cls_gts_ignore.append(ann['bboxes_ignore'][ignore_mask, :])

    return cls_dets, cls_gts, cls_gts_ignore
def get_cls_group_ofs(annotations, class_id):
    """Collect the `gt_group_of` flags of one class (used in Open Images).

    Args:
        annotations (list[dict]): Same as `eval_map()`.
        class_id (int): ID of a specific class.

    Returns:
        list[np.ndarray]: Per-image `gt_group_of` flags of the class. Images
        without a `gt_is_group_ofs` entry contribute an empty (0, 1) bool
        array.
    """
    gt_group_ofs = []
    for ann in annotations:
        class_mask = ann['labels'] == class_id
        group_flags = ann.get('gt_is_group_ofs', None)
        if group_flags is None:
            gt_group_ofs.append(np.empty((0, 1), dtype=bool))
        else:
            gt_group_ofs.append(group_flags[class_mask])
    return gt_group_ofs
def eval_map(det_results,
             annotations,
             scale_ranges=None,
             iou_thr=0.5,
             ioa_thr=None,
             dataset=None,
             logger=None,
             tpfp_fn=None,
             nproc=4,
             use_legacy_coordinate=False,
             use_group_of=False,
             eval_mode='area'):
    """Evaluate mAP of a dataset.

    Args:
        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
            The outer list indicates images, and the inner list indicates
            per-class detected bboxes.
        annotations (list[dict]): Ground truth annotations where each item of
            the list indicates an image. Keys of annotations are:

            - `bboxes`: numpy array of shape (n, 4)
            - `labels`: numpy array of shape (n, )
            - `bboxes_ignore` (optional): numpy array of shape (k, 4)
            - `labels_ignore` (optional): numpy array of shape (k, )
        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
            in the format [(min1, max1), (min2, max2), ...]. A range of
            (32, 64) means the area range between (32**2, 64**2).
            Defaults to None.
        iou_thr (float): IoU threshold to be considered as matched.
            Defaults to 0.5.
        ioa_thr (float | None): IoA threshold to be considered as matched,
            which only used in OpenImages evaluation. When None, the
            tp/fp function's own default is used. Defaults to None.
        dataset (list[str] | str | None): Dataset name or dataset classes,
            there are minor differences in metrics for different datasets, e.g.
            "voc", "imagenet_det", etc. Defaults to None.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmengine.logging.print_log()` for details.
            Defaults to None.
        tpfp_fn (callable | None): The function used to determine true/
            false positives. If None, :func:`tpfp_default` is used as default
            unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
            case) or an Open Images dataset / ``use_group_of`` is requested
            (:func:`tpfp_openimages`). If it is given as a function, then
            this function is used to evaluate tp & fp. Default None.
        nproc (int): Processes used for computing TP and FP.
            Defaults to 4.
        use_legacy_coordinate (bool): Whether to use coordinate system in
            mmdet v1.x. which means width, height should be
            calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
            Defaults to False.
        use_group_of (bool): Whether to use group of when calculate TP and FP,
            which only used in OpenImages evaluation. Defaults to False.
        eval_mode (str): 'area' or '11points', 'area' means calculating the
            area under precision-recall curve, '11points' means calculating
            the average precision of recalls at [0, 0.1, ..., 1],
            PASCAL VOC2007 uses `11points` as default evaluate mode, while
            others are 'area'. Defaults to 'area'.

    Returns:
        tuple: (mAP, [dict, dict, ...]). mAP is a float, or a list with one
        value per scale when ``scale_ranges`` is given. Each dict holds
        `num_gts`, `num_dets`, `recall`, `precision` and `ap` of one class.

    Raises:
        ValueError: If ``tpfp_fn`` is neither None nor callable.
    """
    assert len(det_results) == len(annotations)
    assert eval_mode in ['area', '11points'], \
        f'Unrecognized {eval_mode} mode, only "area" and "11points" ' \
        'are supported'
    extra_length = 1. if use_legacy_coordinate else 0.

    num_imgs = len(det_results)
    num_scales = len(scale_ranges) if scale_ranges is not None else 1
    num_classes = len(det_results[0])  # positive class num
    area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
                   if scale_ranges is not None else None)

    # choose proper function according to datasets to compute tp and fp;
    # the choice does not depend on the class, so make it once and before
    # any worker process is started
    if tpfp_fn is None:
        if dataset in ['det', 'vid']:
            tpfp_fn = tpfp_imagenet
        elif dataset in ['oid_challenge', 'oid_v6'] \
                or use_group_of is True:
            tpfp_fn = tpfp_openimages
        else:
            tpfp_fn = tpfp_default
    if not callable(tpfp_fn):
        raise ValueError(
            f'tpfp_fn has to be a function or None, but got {tpfp_fn}')

    # There is no need to use multi processes to process
    # when num_imgs = 1 .
    pool = None
    if num_imgs > 1:
        assert nproc > 0, 'nproc must be at least one.'
        nproc = min(nproc, num_imgs)
        pool = Pool(nproc)

    eval_results = []
    try:
        for i in range(num_classes):
            # get gt and det bboxes of this class
            cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
                det_results, annotations, i)

            if pool is not None:
                # compute tp and fp for each image with multiple processes
                args = []
                if use_group_of:
                    # used in Open Images Dataset evaluation
                    args.append(get_cls_group_ofs(annotations, i))
                    args.append([use_group_of for _ in range(num_imgs)])
                    # starmap passes everything positionally, so ioa_thr
                    # can only be forwarded after the two group-of
                    # arguments; on its own it would land in the
                    # gt_bboxes_group_of slot.
                    if ioa_thr is not None:
                        args.append([ioa_thr for _ in range(num_imgs)])
                tpfp = pool.starmap(
                    tpfp_fn,
                    zip(cls_dets, cls_gts, cls_gts_ignore,
                        [iou_thr for _ in range(num_imgs)],
                        [area_ranges for _ in range(num_imgs)],
                        [use_legacy_coordinate for _ in range(num_imgs)],
                        *args))
            else:
                extra_kwargs = dict(
                    gt_bboxes_group_of=(get_cls_group_ofs(annotations, i)[0]
                                        if use_group_of else None),
                    use_group_of=use_group_of)
                # as in the multi-process path, leave the callee's default
                # IoA threshold in place when none was requested
                if ioa_thr is not None:
                    extra_kwargs['ioa_thr'] = ioa_thr
                tpfp = [
                    tpfp_fn(cls_dets[0], cls_gts[0], cls_gts_ignore[0],
                            iou_thr, area_ranges, use_legacy_coordinate,
                            **extra_kwargs)
                ]

            # a tp/fp function may return (tp, fp) or (tp, fp, det_bboxes);
            # the latter carries the detections to rank by
            unpacked = tuple(zip(*tpfp))
            tp, fp = unpacked[0], unpacked[1]
            if len(unpacked) > 2:
                cls_dets = unpacked[2]

            # calculate gt number of each scale
            # ignored gts or gts beyond the specific scale are not counted
            num_gts = np.zeros(num_scales, dtype=int)
            for bbox in cls_gts:
                if area_ranges is None:
                    num_gts[0] += bbox.shape[0]
                else:
                    gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * (
                        bbox[:, 3] - bbox[:, 1] + extra_length)
                    for k, (min_area, max_area) in enumerate(area_ranges):
                        num_gts[k] += np.sum((gt_areas >= min_area)
                                             & (gt_areas < max_area))

            # sort all det bboxes by score, also sort tp and fp
            cls_dets = np.vstack(cls_dets)
            num_dets = cls_dets.shape[0]
            sort_inds = np.argsort(-cls_dets[:, -1])
            tp = np.hstack(tp)[:, sort_inds]
            fp = np.hstack(fp)[:, sort_inds]

            # calculate recall and precision with tp and fp
            tp = np.cumsum(tp, axis=1)
            fp = np.cumsum(fp, axis=1)
            eps = np.finfo(np.float32).eps
            recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
            precisions = tp / np.maximum((tp + fp), eps)

            # calculate AP
            if scale_ranges is None:
                recalls = recalls[0, :]
                precisions = precisions[0, :]
                num_gts = num_gts.item()
            ap = average_precision(recalls, precisions, eval_mode)
            eval_results.append({
                'num_gts': num_gts,
                'num_dets': num_dets,
                'recall': recalls,
                'precision': precisions,
                'ap': ap
            })
    finally:
        # release the worker processes even if a class failed to evaluate
        if pool is not None:
            pool.close()
            pool.join()

    if scale_ranges is not None:
        # shape (num_classes, num_scales)
        all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
        all_num_gts = np.vstack(
            [cls_result['num_gts'] for cls_result in eval_results])
        mean_ap = []
        for i in range(num_scales):
            # only classes that have gts in this scale contribute
            if np.any(all_num_gts[:, i] > 0):
                mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
            else:
                mean_ap.append(0.0)
    else:
        aps = [
            cls_result['ap'] for cls_result in eval_results
            if cls_result['num_gts'] > 0
        ]
        mean_ap = np.array(aps).mean().item() if aps else 0.0

    # NOTE: the squared (area) ranges are what gets reported per scale
    print_map_summary(
        mean_ap, eval_results, dataset, area_ranges, logger=logger)

    return mean_ap, eval_results
def print_map_summary(mean_ap,
                      results,
                      dataset=None,
                      scale_ranges=None,
                      logger=None):
    """Print mAP and results of each class.

    One table per scale is logged, listing the gts/dets/recall/AP of every
    class followed by the mAP. Nothing is printed when ``logger`` is
    'silent'.

    Args:
        mean_ap (float): Calculated from `eval_map()`.
        results (list[dict]): Calculated from `eval_map()`.
        dataset (list[str] | str | None): Dataset name or dataset classes.
        scale_ranges (list[tuple] | None): Range of scales to be evaluated.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmengine.logging.print_log()` for details.
            Defaults to None.
    """
    if logger == 'silent':
        return

    # an array-valued AP means one entry per scale
    first_ap = results[0]['ap']
    num_scales = len(first_ap) if isinstance(first_ap, np.ndarray) else 1
    if scale_ranges is not None:
        assert len(scale_ranges) == num_scales

    num_classes = len(results)
    recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
    aps = np.zeros((num_scales, num_classes), dtype=np.float32)
    num_gts = np.zeros((num_scales, num_classes), dtype=int)
    for cls_idx, cls_result in enumerate(results):
        if cls_result['recall'].size > 0:
            # the last point of the curve is the final recall
            final_recall = np.array(cls_result['recall'], ndmin=2)[:, -1]
            recalls[:, cls_idx] = final_recall
        aps[:, cls_idx] = cls_result['ap']
        num_gts[:, cls_idx] = cls_result['num_gts']

    if dataset is None:
        label_names = [str(cls_idx) for cls_idx in range(num_classes)]
    elif is_str(dataset):
        label_names = get_classes(dataset)
    else:
        label_names = dataset

    if not isinstance(mean_ap, list):
        mean_ap = [mean_ap]

    header = ['class', 'gts', 'dets', 'recall', 'ap']
    for scale_idx in range(num_scales):
        if scale_ranges is not None:
            print_log(f'Scale range {scale_ranges[scale_idx]}', logger=logger)
        table_data = [header]
        table_data.extend([
            label_names[cls_idx], num_gts[scale_idx, cls_idx],
            results[cls_idx]['num_dets'],
            f'{recalls[scale_idx, cls_idx]:.3f}',
            f'{aps[scale_idx, cls_idx]:.3f}'
        ] for cls_idx in range(num_classes))
        table_data.append(['mAP', '', '', '', f'{mean_ap[scale_idx]:.3f}'])
        table = AsciiTable(table_data)
        table.inner_footing_row_border = True
        print_log('\n' + table.table, logger=logger)
|