mean_ap.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from multiprocessing import Pool
  3. import numpy as np
  4. from mmengine.logging import print_log
  5. from mmengine.utils import is_str
  6. from terminaltables import AsciiTable
  7. from .bbox_overlaps import bbox_overlaps
  8. from .class_names import get_classes
  9. def average_precision(recalls, precisions, mode='area'):
  10. """Calculate average precision (for single or multiple scales).
  11. Args:
  12. recalls (ndarray): shape (num_scales, num_dets) or (num_dets, )
  13. precisions (ndarray): shape (num_scales, num_dets) or (num_dets, )
  14. mode (str): 'area' or '11points', 'area' means calculating the area
  15. under precision-recall curve, '11points' means calculating
  16. the average precision of recalls at [0, 0.1, ..., 1]
  17. Returns:
  18. float or ndarray: calculated average precision
  19. """
  20. no_scale = False
  21. if recalls.ndim == 1:
  22. no_scale = True
  23. recalls = recalls[np.newaxis, :]
  24. precisions = precisions[np.newaxis, :]
  25. assert recalls.shape == precisions.shape and recalls.ndim == 2
  26. num_scales = recalls.shape[0]
  27. ap = np.zeros(num_scales, dtype=np.float32)
  28. if mode == 'area':
  29. zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
  30. ones = np.ones((num_scales, 1), dtype=recalls.dtype)
  31. mrec = np.hstack((zeros, recalls, ones))
  32. mpre = np.hstack((zeros, precisions, zeros))
  33. for i in range(mpre.shape[1] - 1, 0, -1):
  34. mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
  35. for i in range(num_scales):
  36. ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
  37. ap[i] = np.sum(
  38. (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
  39. elif mode == '11points':
  40. for i in range(num_scales):
  41. for thr in np.arange(0, 1 + 1e-3, 0.1):
  42. precs = precisions[i, recalls[i, :] >= thr]
  43. prec = precs.max() if precs.size > 0 else 0
  44. ap[i] += prec
  45. ap /= 11
  46. else:
  47. raise ValueError(
  48. 'Unrecognized mode, only "area" and "11points" are supported')
  49. if no_scale:
  50. ap = ap[0]
  51. return ap
  52. def tpfp_imagenet(det_bboxes,
  53. gt_bboxes,
  54. gt_bboxes_ignore=None,
  55. default_iou_thr=0.5,
  56. area_ranges=None,
  57. use_legacy_coordinate=False,
  58. **kwargs):
  59. """Check if detected bboxes are true positive or false positive.
  60. Args:
  61. det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
  62. gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
  63. gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
  64. of shape (k, 4). Defaults to None
  65. default_iou_thr (float): IoU threshold to be considered as matched for
  66. medium and large bboxes (small ones have special rules).
  67. Defaults to 0.5.
  68. area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
  69. in the format [(min1, max1), (min2, max2), ...]. Defaults to None.
  70. use_legacy_coordinate (bool): Whether to use coordinate system in
  71. mmdet v1.x. which means width, height should be
  72. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  73. Defaults to False.
  74. Returns:
  75. tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
  76. each array is (num_scales, m).
  77. """
  78. if not use_legacy_coordinate:
  79. extra_length = 0.
  80. else:
  81. extra_length = 1.
  82. # an indicator of ignored gts
  83. gt_ignore_inds = np.concatenate(
  84. (np.zeros(gt_bboxes.shape[0],
  85. dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
  86. # stack gt_bboxes and gt_bboxes_ignore for convenience
  87. gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
  88. num_dets = det_bboxes.shape[0]
  89. num_gts = gt_bboxes.shape[0]
  90. if area_ranges is None:
  91. area_ranges = [(None, None)]
  92. num_scales = len(area_ranges)
  93. # tp and fp are of shape (num_scales, num_gts), each row is tp or fp
  94. # of a certain scale.
  95. tp = np.zeros((num_scales, num_dets), dtype=np.float32)
  96. fp = np.zeros((num_scales, num_dets), dtype=np.float32)
  97. if gt_bboxes.shape[0] == 0:
  98. if area_ranges == [(None, None)]:
  99. fp[...] = 1
  100. else:
  101. det_areas = (
  102. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  103. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  104. for i, (min_area, max_area) in enumerate(area_ranges):
  105. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  106. return tp, fp
  107. ious = bbox_overlaps(
  108. det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate)
  109. gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length
  110. gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length
  111. iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
  112. default_iou_thr)
  113. # sort all detections by scores in descending order
  114. sort_inds = np.argsort(-det_bboxes[:, -1])
  115. for k, (min_area, max_area) in enumerate(area_ranges):
  116. gt_covered = np.zeros(num_gts, dtype=bool)
  117. # if no area range is specified, gt_area_ignore is all False
  118. if min_area is None:
  119. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  120. else:
  121. gt_areas = gt_w * gt_h
  122. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  123. for i in sort_inds:
  124. max_iou = -1
  125. matched_gt = -1
  126. # find best overlapped available gt
  127. for j in range(num_gts):
  128. # different from PASCAL VOC: allow finding other gts if the
  129. # best overlapped ones are already matched by other det bboxes
  130. if gt_covered[j]:
  131. continue
  132. elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
  133. max_iou = ious[i, j]
  134. matched_gt = j
  135. # there are 4 cases for a det bbox:
  136. # 1. it matches a gt, tp = 1, fp = 0
  137. # 2. it matches an ignored gt, tp = 0, fp = 0
  138. # 3. it matches no gt and within area range, tp = 0, fp = 1
  139. # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
  140. if matched_gt >= 0:
  141. gt_covered[matched_gt] = 1
  142. if not (gt_ignore_inds[matched_gt]
  143. or gt_area_ignore[matched_gt]):
  144. tp[k, i] = 1
  145. elif min_area is None:
  146. fp[k, i] = 1
  147. else:
  148. bbox = det_bboxes[i, :4]
  149. area = (bbox[2] - bbox[0] + extra_length) * (
  150. bbox[3] - bbox[1] + extra_length)
  151. if area >= min_area and area < max_area:
  152. fp[k, i] = 1
  153. return tp, fp
  154. def tpfp_default(det_bboxes,
  155. gt_bboxes,
  156. gt_bboxes_ignore=None,
  157. iou_thr=0.5,
  158. area_ranges=None,
  159. use_legacy_coordinate=False,
  160. **kwargs):
  161. """Check if detected bboxes are true positive or false positive.
  162. Args:
  163. det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
  164. gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
  165. gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
  166. of shape (k, 4). Defaults to None
  167. iou_thr (float): IoU threshold to be considered as matched.
  168. Defaults to 0.5.
  169. area_ranges (list[tuple] | None): Range of bbox areas to be
  170. evaluated, in the format [(min1, max1), (min2, max2), ...].
  171. Defaults to None.
  172. use_legacy_coordinate (bool): Whether to use coordinate system in
  173. mmdet v1.x. which means width, height should be
  174. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  175. Defaults to False.
  176. Returns:
  177. tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
  178. each array is (num_scales, m).
  179. """
  180. if not use_legacy_coordinate:
  181. extra_length = 0.
  182. else:
  183. extra_length = 1.
  184. # an indicator of ignored gts
  185. gt_ignore_inds = np.concatenate(
  186. (np.zeros(gt_bboxes.shape[0],
  187. dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
  188. # stack gt_bboxes and gt_bboxes_ignore for convenience
  189. gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
  190. num_dets = det_bboxes.shape[0]
  191. num_gts = gt_bboxes.shape[0]
  192. if area_ranges is None:
  193. area_ranges = [(None, None)]
  194. num_scales = len(area_ranges)
  195. # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of
  196. # a certain scale
  197. tp = np.zeros((num_scales, num_dets), dtype=np.float32)
  198. fp = np.zeros((num_scales, num_dets), dtype=np.float32)
  199. # if there is no gt bboxes in this image, then all det bboxes
  200. # within area range are false positives
  201. if gt_bboxes.shape[0] == 0:
  202. if area_ranges == [(None, None)]:
  203. fp[...] = 1
  204. else:
  205. det_areas = (
  206. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  207. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  208. for i, (min_area, max_area) in enumerate(area_ranges):
  209. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  210. return tp, fp
  211. ious = bbox_overlaps(
  212. det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
  213. # for each det, the max iou with all gts
  214. ious_max = ious.max(axis=1)
  215. # for each det, which gt overlaps most with it
  216. ious_argmax = ious.argmax(axis=1)
  217. # sort all dets in descending order by scores
  218. sort_inds = np.argsort(-det_bboxes[:, -1])
  219. for k, (min_area, max_area) in enumerate(area_ranges):
  220. gt_covered = np.zeros(num_gts, dtype=bool)
  221. # if no area range is specified, gt_area_ignore is all False
  222. if min_area is None:
  223. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  224. else:
  225. gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
  226. gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
  227. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  228. for i in sort_inds:
  229. if ious_max[i] >= iou_thr:
  230. matched_gt = ious_argmax[i]
  231. if not (gt_ignore_inds[matched_gt]
  232. or gt_area_ignore[matched_gt]):
  233. if not gt_covered[matched_gt]:
  234. gt_covered[matched_gt] = True
  235. tp[k, i] = 1
  236. else:
  237. fp[k, i] = 1
  238. # otherwise ignore this detected bbox, tp = 0, fp = 0
  239. elif min_area is None:
  240. fp[k, i] = 1
  241. else:
  242. bbox = det_bboxes[i, :4]
  243. area = (bbox[2] - bbox[0] + extra_length) * (
  244. bbox[3] - bbox[1] + extra_length)
  245. if area >= min_area and area < max_area:
  246. fp[k, i] = 1
  247. return tp, fp
  248. def tpfp_openimages(det_bboxes,
  249. gt_bboxes,
  250. gt_bboxes_ignore=None,
  251. iou_thr=0.5,
  252. area_ranges=None,
  253. use_legacy_coordinate=False,
  254. gt_bboxes_group_of=None,
  255. use_group_of=True,
  256. ioa_thr=0.5,
  257. **kwargs):
  258. """Check if detected bboxes are true positive or false positive.
  259. Args:
  260. det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
  261. gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
  262. gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
  263. of shape (k, 4). Defaults to None
  264. iou_thr (float): IoU threshold to be considered as matched.
  265. Defaults to 0.5.
  266. area_ranges (list[tuple] | None): Range of bbox areas to be
  267. evaluated, in the format [(min1, max1), (min2, max2), ...].
  268. Defaults to None.
  269. use_legacy_coordinate (bool): Whether to use coordinate system in
  270. mmdet v1.x. which means width, height should be
  271. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  272. Defaults to False.
  273. gt_bboxes_group_of (ndarray): GT group_of of this image, of shape
  274. (k, 1). Defaults to None
  275. use_group_of (bool): Whether to use group of when calculate TP and FP,
  276. which only used in OpenImages evaluation. Defaults to True.
  277. ioa_thr (float | None): IoA threshold to be considered as matched,
  278. which only used in OpenImages evaluation. Defaults to 0.5.
  279. Returns:
  280. tuple[np.ndarray]: Returns a tuple (tp, fp, det_bboxes), where
  281. (tp, fp) whose elements are 0 and 1. The shape of each array is
  282. (num_scales, m). (det_bboxes) whose will filter those are not
  283. matched by group of gts when processing Open Images evaluation.
  284. The shape is (num_scales, m).
  285. """
  286. if not use_legacy_coordinate:
  287. extra_length = 0.
  288. else:
  289. extra_length = 1.
  290. # an indicator of ignored gts
  291. gt_ignore_inds = np.concatenate(
  292. (np.zeros(gt_bboxes.shape[0],
  293. dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
  294. # stack gt_bboxes and gt_bboxes_ignore for convenience
  295. gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
  296. num_dets = det_bboxes.shape[0]
  297. num_gts = gt_bboxes.shape[0]
  298. if area_ranges is None:
  299. area_ranges = [(None, None)]
  300. num_scales = len(area_ranges)
  301. # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of
  302. # a certain scale
  303. tp = np.zeros((num_scales, num_dets), dtype=np.float32)
  304. fp = np.zeros((num_scales, num_dets), dtype=np.float32)
  305. # if there is no gt bboxes in this image, then all det bboxes
  306. # within area range are false positives
  307. if gt_bboxes.shape[0] == 0:
  308. if area_ranges == [(None, None)]:
  309. fp[...] = 1
  310. else:
  311. det_areas = (
  312. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  313. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  314. for i, (min_area, max_area) in enumerate(area_ranges):
  315. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  316. return tp, fp, det_bboxes
  317. if gt_bboxes_group_of is not None and use_group_of:
  318. # if handle group-of boxes, divided gt boxes into two parts:
  319. # non-group-of and group-of.Then calculate ious and ioas through
  320. # non-group-of group-of gts respectively. This only used in
  321. # OpenImages evaluation.
  322. assert gt_bboxes_group_of.shape[0] == gt_bboxes.shape[0]
  323. non_group_gt_bboxes = gt_bboxes[~gt_bboxes_group_of]
  324. group_gt_bboxes = gt_bboxes[gt_bboxes_group_of]
  325. num_gts_group = group_gt_bboxes.shape[0]
  326. ious = bbox_overlaps(det_bboxes, non_group_gt_bboxes)
  327. ioas = bbox_overlaps(det_bboxes, group_gt_bboxes, mode='iof')
  328. else:
  329. # if not consider group-of boxes, only calculate ious through gt boxes
  330. ious = bbox_overlaps(
  331. det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
  332. ioas = None
  333. if ious.shape[1] > 0:
  334. # for each det, the max iou with all gts
  335. ious_max = ious.max(axis=1)
  336. # for each det, which gt overlaps most with it
  337. ious_argmax = ious.argmax(axis=1)
  338. # sort all dets in descending order by scores
  339. sort_inds = np.argsort(-det_bboxes[:, -1])
  340. for k, (min_area, max_area) in enumerate(area_ranges):
  341. gt_covered = np.zeros(num_gts, dtype=bool)
  342. # if no area range is specified, gt_area_ignore is all False
  343. if min_area is None:
  344. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  345. else:
  346. gt_areas = (
  347. gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
  348. gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
  349. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  350. for i in sort_inds:
  351. if ious_max[i] >= iou_thr:
  352. matched_gt = ious_argmax[i]
  353. if not (gt_ignore_inds[matched_gt]
  354. or gt_area_ignore[matched_gt]):
  355. if not gt_covered[matched_gt]:
  356. gt_covered[matched_gt] = True
  357. tp[k, i] = 1
  358. else:
  359. fp[k, i] = 1
  360. # otherwise ignore this detected bbox, tp = 0, fp = 0
  361. elif min_area is None:
  362. fp[k, i] = 1
  363. else:
  364. bbox = det_bboxes[i, :4]
  365. area = (bbox[2] - bbox[0] + extra_length) * (
  366. bbox[3] - bbox[1] + extra_length)
  367. if area >= min_area and area < max_area:
  368. fp[k, i] = 1
  369. else:
  370. # if there is no no-group-of gt bboxes in this image,
  371. # then all det bboxes within area range are false positives.
  372. # Only used in OpenImages evaluation.
  373. if area_ranges == [(None, None)]:
  374. fp[...] = 1
  375. else:
  376. det_areas = (
  377. det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
  378. det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
  379. for i, (min_area, max_area) in enumerate(area_ranges):
  380. fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
  381. if ioas is None or ioas.shape[1] <= 0:
  382. return tp, fp, det_bboxes
  383. else:
  384. # The evaluation of group-of TP and FP are done in two stages:
  385. # 1. All detections are first matched to non group-of boxes; true
  386. # positives are determined.
  387. # 2. Detections that are determined as false positives are matched
  388. # against group-of boxes and calculated group-of TP and FP.
  389. # Only used in OpenImages evaluation.
  390. det_bboxes_group = np.zeros(
  391. (num_scales, ioas.shape[1], det_bboxes.shape[1]), dtype=float)
  392. match_group_of = np.zeros((num_scales, num_dets), dtype=bool)
  393. tp_group = np.zeros((num_scales, num_gts_group), dtype=np.float32)
  394. ioas_max = ioas.max(axis=1)
  395. # for each det, which gt overlaps most with it
  396. ioas_argmax = ioas.argmax(axis=1)
  397. # sort all dets in descending order by scores
  398. sort_inds = np.argsort(-det_bboxes[:, -1])
  399. for k, (min_area, max_area) in enumerate(area_ranges):
  400. box_is_covered = tp[k]
  401. # if no area range is specified, gt_area_ignore is all False
  402. if min_area is None:
  403. gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
  404. else:
  405. gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
  406. gt_bboxes[:, 3] - gt_bboxes[:, 1])
  407. gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
  408. for i in sort_inds:
  409. matched_gt = ioas_argmax[i]
  410. if not box_is_covered[i]:
  411. if ioas_max[i] >= ioa_thr:
  412. if not (gt_ignore_inds[matched_gt]
  413. or gt_area_ignore[matched_gt]):
  414. if not tp_group[k, matched_gt]:
  415. tp_group[k, matched_gt] = 1
  416. match_group_of[k, i] = True
  417. else:
  418. match_group_of[k, i] = True
  419. if det_bboxes_group[k, matched_gt, -1] < \
  420. det_bboxes[i, -1]:
  421. det_bboxes_group[k, matched_gt] = \
  422. det_bboxes[i]
  423. fp_group = (tp_group <= 0).astype(float)
  424. tps = []
  425. fps = []
  426. # concatenate tp, fp, and det-boxes which not matched group of
  427. # gt boxes and tp_group, fp_group, and det_bboxes_group which
  428. # matched group of boxes respectively.
  429. for i in range(num_scales):
  430. tps.append(
  431. np.concatenate((tp[i][~match_group_of[i]], tp_group[i])))
  432. fps.append(
  433. np.concatenate((fp[i][~match_group_of[i]], fp_group[i])))
  434. det_bboxes = np.concatenate(
  435. (det_bboxes[~match_group_of[i]], det_bboxes_group[i]))
  436. tp = np.vstack(tps)
  437. fp = np.vstack(fps)
  438. return tp, fp, det_bboxes
  439. def get_cls_results(det_results, annotations, class_id):
  440. """Get det results and gt information of a certain class.
  441. Args:
  442. det_results (list[list]): Same as `eval_map()`.
  443. annotations (list[dict]): Same as `eval_map()`.
  444. class_id (int): ID of a specific class.
  445. Returns:
  446. tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
  447. """
  448. cls_dets = [img_res[class_id] for img_res in det_results]
  449. cls_gts = []
  450. cls_gts_ignore = []
  451. for ann in annotations:
  452. gt_inds = ann['labels'] == class_id
  453. cls_gts.append(ann['bboxes'][gt_inds, :])
  454. if ann.get('labels_ignore', None) is not None:
  455. ignore_inds = ann['labels_ignore'] == class_id
  456. cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
  457. else:
  458. cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
  459. return cls_dets, cls_gts, cls_gts_ignore
  460. def get_cls_group_ofs(annotations, class_id):
  461. """Get `gt_group_of` of a certain class, which is used in Open Images.
  462. Args:
  463. annotations (list[dict]): Same as `eval_map()`.
  464. class_id (int): ID of a specific class.
  465. Returns:
  466. list[np.ndarray]: `gt_group_of` of a certain class.
  467. """
  468. gt_group_ofs = []
  469. for ann in annotations:
  470. gt_inds = ann['labels'] == class_id
  471. if ann.get('gt_is_group_ofs', None) is not None:
  472. gt_group_ofs.append(ann['gt_is_group_ofs'][gt_inds])
  473. else:
  474. gt_group_ofs.append(np.empty((0, 1), dtype=bool))
  475. return gt_group_ofs
  476. def eval_map(det_results,
  477. annotations,
  478. scale_ranges=None,
  479. iou_thr=0.5,
  480. ioa_thr=None,
  481. dataset=None,
  482. logger=None,
  483. tpfp_fn=None,
  484. nproc=4,
  485. use_legacy_coordinate=False,
  486. use_group_of=False,
  487. eval_mode='area'):
  488. """Evaluate mAP of a dataset.
  489. Args:
  490. det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
  491. The outer list indicates images, and the inner list indicates
  492. per-class detected bboxes.
  493. annotations (list[dict]): Ground truth annotations where each item of
  494. the list indicates an image. Keys of annotations are:
  495. - `bboxes`: numpy array of shape (n, 4)
  496. - `labels`: numpy array of shape (n, )
  497. - `bboxes_ignore` (optional): numpy array of shape (k, 4)
  498. - `labels_ignore` (optional): numpy array of shape (k, )
  499. scale_ranges (list[tuple] | None): Range of scales to be evaluated,
  500. in the format [(min1, max1), (min2, max2), ...]. A range of
  501. (32, 64) means the area range between (32**2, 64**2).
  502. Defaults to None.
  503. iou_thr (float): IoU threshold to be considered as matched.
  504. Defaults to 0.5.
  505. ioa_thr (float | None): IoA threshold to be considered as matched,
  506. which only used in OpenImages evaluation. Defaults to None.
  507. dataset (list[str] | str | None): Dataset name or dataset classes,
  508. there are minor differences in metrics for different datasets, e.g.
  509. "voc", "imagenet_det", etc. Defaults to None.
  510. logger (logging.Logger | str | None): The way to print the mAP
  511. summary. See `mmengine.logging.print_log()` for details.
  512. Defaults to None.
  513. tpfp_fn (callable | None): The function used to determine true/
  514. false positives. If None, :func:`tpfp_default` is used as default
  515. unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
  516. case). If it is given as a function, then this function is used
  517. to evaluate tp & fp. Default None.
  518. nproc (int): Processes used for computing TP and FP.
  519. Defaults to 4.
  520. use_legacy_coordinate (bool): Whether to use coordinate system in
  521. mmdet v1.x. which means width, height should be
  522. calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively.
  523. Defaults to False.
  524. use_group_of (bool): Whether to use group of when calculate TP and FP,
  525. which only used in OpenImages evaluation. Defaults to False.
  526. eval_mode (str): 'area' or '11points', 'area' means calculating the
  527. area under precision-recall curve, '11points' means calculating
  528. the average precision of recalls at [0, 0.1, ..., 1],
  529. PASCAL VOC2007 uses `11points` as default evaluate mode, while
  530. others are 'area'. Defaults to 'area'.
  531. Returns:
  532. tuple: (mAP, [dict, dict, ...])
  533. """
  534. assert len(det_results) == len(annotations)
  535. assert eval_mode in ['area', '11points'], \
  536. f'Unrecognized {eval_mode} mode, only "area" and "11points" ' \
  537. 'are supported'
  538. if not use_legacy_coordinate:
  539. extra_length = 0.
  540. else:
  541. extra_length = 1.
  542. num_imgs = len(det_results)
  543. num_scales = len(scale_ranges) if scale_ranges is not None else 1
  544. num_classes = len(det_results[0]) # positive class num
  545. area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
  546. if scale_ranges is not None else None)
  547. # There is no need to use multi processes to process
  548. # when num_imgs = 1 .
  549. if num_imgs > 1:
  550. assert nproc > 0, 'nproc must be at least one.'
  551. nproc = min(nproc, num_imgs)
  552. pool = Pool(nproc)
  553. eval_results = []
  554. for i in range(num_classes):
  555. # get gt and det bboxes of this class
  556. cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
  557. det_results, annotations, i)
  558. # choose proper function according to datasets to compute tp and fp
  559. if tpfp_fn is None:
  560. if dataset in ['det', 'vid']:
  561. tpfp_fn = tpfp_imagenet
  562. elif dataset in ['oid_challenge', 'oid_v6'] \
  563. or use_group_of is True:
  564. tpfp_fn = tpfp_openimages
  565. else:
  566. tpfp_fn = tpfp_default
  567. if not callable(tpfp_fn):
  568. raise ValueError(
  569. f'tpfp_fn has to be a function or None, but got {tpfp_fn}')
  570. if num_imgs > 1:
  571. # compute tp and fp for each image with multiple processes
  572. args = []
  573. if use_group_of:
  574. # used in Open Images Dataset evaluation
  575. gt_group_ofs = get_cls_group_ofs(annotations, i)
  576. args.append(gt_group_ofs)
  577. args.append([use_group_of for _ in range(num_imgs)])
  578. if ioa_thr is not None:
  579. args.append([ioa_thr for _ in range(num_imgs)])
  580. tpfp = pool.starmap(
  581. tpfp_fn,
  582. zip(cls_dets, cls_gts, cls_gts_ignore,
  583. [iou_thr for _ in range(num_imgs)],
  584. [area_ranges for _ in range(num_imgs)],
  585. [use_legacy_coordinate for _ in range(num_imgs)], *args))
  586. else:
  587. tpfp = tpfp_fn(
  588. cls_dets[0],
  589. cls_gts[0],
  590. cls_gts_ignore[0],
  591. iou_thr,
  592. area_ranges,
  593. use_legacy_coordinate,
  594. gt_bboxes_group_of=(get_cls_group_ofs(annotations, i)[0]
  595. if use_group_of else None),
  596. use_group_of=use_group_of,
  597. ioa_thr=ioa_thr)
  598. tpfp = [tpfp]
  599. if use_group_of:
  600. tp, fp, cls_dets = tuple(zip(*tpfp))
  601. else:
  602. tp, fp = tuple(zip(*tpfp))
  603. # calculate gt number of each scale
  604. # ignored gts or gts beyond the specific scale are not counted
  605. num_gts = np.zeros(num_scales, dtype=int)
  606. for j, bbox in enumerate(cls_gts):
  607. if area_ranges is None:
  608. num_gts[0] += bbox.shape[0]
  609. else:
  610. gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * (
  611. bbox[:, 3] - bbox[:, 1] + extra_length)
  612. for k, (min_area, max_area) in enumerate(area_ranges):
  613. num_gts[k] += np.sum((gt_areas >= min_area)
  614. & (gt_areas < max_area))
  615. # sort all det bboxes by score, also sort tp and fp
  616. cls_dets = np.vstack(cls_dets)
  617. num_dets = cls_dets.shape[0]
  618. sort_inds = np.argsort(-cls_dets[:, -1])
  619. tp = np.hstack(tp)[:, sort_inds]
  620. fp = np.hstack(fp)[:, sort_inds]
  621. # calculate recall and precision with tp and fp
  622. tp = np.cumsum(tp, axis=1)
  623. fp = np.cumsum(fp, axis=1)
  624. eps = np.finfo(np.float32).eps
  625. recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
  626. precisions = tp / np.maximum((tp + fp), eps)
  627. # calculate AP
  628. if scale_ranges is None:
  629. recalls = recalls[0, :]
  630. precisions = precisions[0, :]
  631. num_gts = num_gts.item()
  632. ap = average_precision(recalls, precisions, eval_mode)
  633. eval_results.append({
  634. 'num_gts': num_gts,
  635. 'num_dets': num_dets,
  636. 'recall': recalls,
  637. 'precision': precisions,
  638. 'ap': ap
  639. })
  640. if num_imgs > 1:
  641. pool.close()
  642. if scale_ranges is not None:
  643. # shape (num_classes, num_scales)
  644. all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
  645. all_num_gts = np.vstack(
  646. [cls_result['num_gts'] for cls_result in eval_results])
  647. mean_ap = []
  648. for i in range(num_scales):
  649. if np.any(all_num_gts[:, i] > 0):
  650. mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
  651. else:
  652. mean_ap.append(0.0)
  653. else:
  654. aps = []
  655. for cls_result in eval_results:
  656. if cls_result['num_gts'] > 0:
  657. aps.append(cls_result['ap'])
  658. mean_ap = np.array(aps).mean().item() if aps else 0.0
  659. print_map_summary(
  660. mean_ap, eval_results, dataset, area_ranges, logger=logger)
  661. return mean_ap, eval_results
  662. def print_map_summary(mean_ap,
  663. results,
  664. dataset=None,
  665. scale_ranges=None,
  666. logger=None):
  667. """Print mAP and results of each class.
  668. A table will be printed to show the gts/dets/recall/AP of each class and
  669. the mAP.
  670. Args:
  671. mean_ap (float): Calculated from `eval_map()`.
  672. results (list[dict]): Calculated from `eval_map()`.
  673. dataset (list[str] | str | None): Dataset name or dataset classes.
  674. scale_ranges (list[tuple] | None): Range of scales to be evaluated.
  675. logger (logging.Logger | str | None): The way to print the mAP
  676. summary. See `mmengine.logging.print_log()` for details.
  677. Defaults to None.
  678. """
  679. if logger == 'silent':
  680. return
  681. if isinstance(results[0]['ap'], np.ndarray):
  682. num_scales = len(results[0]['ap'])
  683. else:
  684. num_scales = 1
  685. if scale_ranges is not None:
  686. assert len(scale_ranges) == num_scales
  687. num_classes = len(results)
  688. recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
  689. aps = np.zeros((num_scales, num_classes), dtype=np.float32)
  690. num_gts = np.zeros((num_scales, num_classes), dtype=int)
  691. for i, cls_result in enumerate(results):
  692. if cls_result['recall'].size > 0:
  693. recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
  694. aps[:, i] = cls_result['ap']
  695. num_gts[:, i] = cls_result['num_gts']
  696. if dataset is None:
  697. label_names = [str(i) for i in range(num_classes)]
  698. elif is_str(dataset):
  699. label_names = get_classes(dataset)
  700. else:
  701. label_names = dataset
  702. if not isinstance(mean_ap, list):
  703. mean_ap = [mean_ap]
  704. header = ['class', 'gts', 'dets', 'recall', 'ap']
  705. for i in range(num_scales):
  706. if scale_ranges is not None:
  707. print_log(f'Scale range {scale_ranges[i]}', logger=logger)
  708. table_data = [header]
  709. for j in range(num_classes):
  710. row_data = [
  711. label_names[j], num_gts[i, j], results[j]['num_dets'],
  712. f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
  713. ]
  714. table_data.append(row_data)
  715. table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
  716. table = AsciiTable(table_data)
  717. table.inner_footing_row_border = True
  718. print_log('\n' + table.table, logger=logger)