recall.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from collections.abc import Sequence
  3. import numpy as np
  4. from mmengine.logging import print_log
  5. from terminaltables import AsciiTable
  6. from .bbox_overlaps import bbox_overlaps
  7. def _recalls(all_ious, proposal_nums, thrs):
  8. img_num = all_ious.shape[0]
  9. total_gt_num = sum([ious.shape[0] for ious in all_ious])
  10. _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
  11. for k, proposal_num in enumerate(proposal_nums):
  12. tmp_ious = np.zeros(0)
  13. for i in range(img_num):
  14. ious = all_ious[i][:, :proposal_num].copy()
  15. gt_ious = np.zeros((ious.shape[0]))
  16. if ious.size == 0:
  17. tmp_ious = np.hstack((tmp_ious, gt_ious))
  18. continue
  19. for j in range(ious.shape[0]):
  20. gt_max_overlaps = ious.argmax(axis=1)
  21. max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
  22. gt_idx = max_ious.argmax()
  23. gt_ious[j] = max_ious[gt_idx]
  24. box_idx = gt_max_overlaps[gt_idx]
  25. ious[gt_idx, :] = -1
  26. ious[:, box_idx] = -1
  27. tmp_ious = np.hstack((tmp_ious, gt_ious))
  28. _ious[k, :] = tmp_ious
  29. _ious = np.fliplr(np.sort(_ious, axis=1))
  30. recalls = np.zeros((proposal_nums.size, thrs.size))
  31. for i, thr in enumerate(thrs):
  32. recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
  33. return recalls
  34. def set_recall_param(proposal_nums, iou_thrs):
  35. """Check proposal_nums and iou_thrs and set correct format."""
  36. if isinstance(proposal_nums, Sequence):
  37. _proposal_nums = np.array(proposal_nums)
  38. elif isinstance(proposal_nums, int):
  39. _proposal_nums = np.array([proposal_nums])
  40. else:
  41. _proposal_nums = proposal_nums
  42. if iou_thrs is None:
  43. _iou_thrs = np.array([0.5])
  44. elif isinstance(iou_thrs, Sequence):
  45. _iou_thrs = np.array(iou_thrs)
  46. elif isinstance(iou_thrs, float):
  47. _iou_thrs = np.array([iou_thrs])
  48. else:
  49. _iou_thrs = iou_thrs
  50. return _proposal_nums, _iou_thrs
  51. def eval_recalls(gts,
  52. proposals,
  53. proposal_nums=None,
  54. iou_thrs=0.5,
  55. logger=None,
  56. use_legacy_coordinate=False):
  57. """Calculate recalls.
  58. Args:
  59. gts (list[ndarray]): a list of arrays of shape (n, 4)
  60. proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5)
  61. proposal_nums (int | Sequence[int]): Top N proposals to be evaluated.
  62. iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5.
  63. logger (logging.Logger | str | None): The way to print the recall
  64. summary. See `mmengine.logging.print_log()` for details.
  65. Default: None.
  66. use_legacy_coordinate (bool): Whether use coordinate system
  67. in mmdet v1.x. "1" was added to both height and width
  68. which means w, h should be
  69. computed as 'x2 - x1 + 1` and 'y2 - y1 + 1'. Default: False.
  70. Returns:
  71. ndarray: recalls of different ious and proposal nums
  72. """
  73. img_num = len(gts)
  74. assert img_num == len(proposals)
  75. proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
  76. all_ious = []
  77. for i in range(img_num):
  78. if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
  79. scores = proposals[i][:, 4]
  80. sort_idx = np.argsort(scores)[::-1]
  81. img_proposal = proposals[i][sort_idx, :]
  82. else:
  83. img_proposal = proposals[i]
  84. prop_num = min(img_proposal.shape[0], proposal_nums[-1])
  85. if gts[i] is None or gts[i].shape[0] == 0:
  86. ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
  87. else:
  88. ious = bbox_overlaps(
  89. gts[i],
  90. img_proposal[:prop_num, :4],
  91. use_legacy_coordinate=use_legacy_coordinate)
  92. all_ious.append(ious)
  93. all_ious = np.array(all_ious)
  94. recalls = _recalls(all_ious, proposal_nums, iou_thrs)
  95. print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger)
  96. return recalls
  97. def print_recall_summary(recalls,
  98. proposal_nums,
  99. iou_thrs,
  100. row_idxs=None,
  101. col_idxs=None,
  102. logger=None):
  103. """Print recalls in a table.
  104. Args:
  105. recalls (ndarray): calculated from `bbox_recalls`
  106. proposal_nums (ndarray or list): top N proposals
  107. iou_thrs (ndarray or list): iou thresholds
  108. row_idxs (ndarray): which rows(proposal nums) to print
  109. col_idxs (ndarray): which cols(iou thresholds) to print
  110. logger (logging.Logger | str | None): The way to print the recall
  111. summary. See `mmengine.logging.print_log()` for details.
  112. Default: None.
  113. """
  114. proposal_nums = np.array(proposal_nums, dtype=np.int32)
  115. iou_thrs = np.array(iou_thrs)
  116. if row_idxs is None:
  117. row_idxs = np.arange(proposal_nums.size)
  118. if col_idxs is None:
  119. col_idxs = np.arange(iou_thrs.size)
  120. row_header = [''] + iou_thrs[col_idxs].tolist()
  121. table_data = [row_header]
  122. for i, num in enumerate(proposal_nums[row_idxs]):
  123. row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()]
  124. row.insert(0, num)
  125. table_data.append(row)
  126. table = AsciiTable(table_data)
  127. print_log('\n' + table.table, logger=logger)
  128. def plot_num_recall(recalls, proposal_nums):
  129. """Plot Proposal_num-Recalls curve.
  130. Args:
  131. recalls(ndarray or list): shape (k,)
  132. proposal_nums(ndarray or list): same shape as `recalls`
  133. """
  134. if isinstance(proposal_nums, np.ndarray):
  135. _proposal_nums = proposal_nums.tolist()
  136. else:
  137. _proposal_nums = proposal_nums
  138. if isinstance(recalls, np.ndarray):
  139. _recalls = recalls.tolist()
  140. else:
  141. _recalls = recalls
  142. import matplotlib.pyplot as plt
  143. f = plt.figure()
  144. plt.plot([0] + _proposal_nums, [0] + _recalls)
  145. plt.xlabel('Proposal num')
  146. plt.ylabel('Recall')
  147. plt.axis([0, proposal_nums.max(), 0, 1])
  148. f.show()
  149. def plot_iou_recall(recalls, iou_thrs):
  150. """Plot IoU-Recalls curve.
  151. Args:
  152. recalls(ndarray or list): shape (k,)
  153. iou_thrs(ndarray or list): same shape as `recalls`
  154. """
  155. if isinstance(iou_thrs, np.ndarray):
  156. _iou_thrs = iou_thrs.tolist()
  157. else:
  158. _iou_thrs = iou_thrs
  159. if isinstance(recalls, np.ndarray):
  160. _recalls = recalls.tolist()
  161. else:
  162. _recalls = recalls
  163. import matplotlib.pyplot as plt
  164. f = plt.figure()
  165. plt.plot(_iou_thrs + [1.0], _recalls + [0.])
  166. plt.xlabel('IoU')
  167. plt.ylabel('Recall')
  168. plt.axis([iou_thrs.min(), 1, 0, 1])
  169. f.show()