# crowdhuman_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import os.path as osp
import tempfile
from collections import OrderedDict
from multiprocessing import Process, Queue
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump, get_text, load
from mmengine.logging import MMLogger
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_bipartite_matching

from mmdet.evaluation.functional.bbox_overlaps import bbox_overlaps
from mmdet.registry import METRICS

PERSON_CLASSES = ['background', 'person']


@METRICS.register_module()
class CrowdHumanMetric(BaseMetric):
    """CrowdHuman evaluation metric.

    Evaluate Average Precision (AP), Miss Rate (MR) and Jaccard Index (JI)
    for detection tasks.

    Args:
        ann_file (str): Path to the annotation file.
        metric (str | List[str]): Metrics to be evaluated. Valid metrics
            include 'AP', 'MR' and 'JI'. Defaults to ['AP', 'MR', 'JI'].
        format_only (bool): Format the output results without performing
            evaluation. It is useful when you want to format the result
            to a specific format and submit it to the test server.
            Defaults to False.
        outfile_prefix (str, optional): The prefix of json files. It includes
            the file path and the prefix of filename, e.g., "a/b/prefix".
            If not specified, a temp file will be created. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
        eval_mode (int): Select the evaluation mode. Valid modes include
            0 (body box only), 1 (head box only) and 2 (both of them).
            Defaults to 0.
        iou_thres (float): IoU threshold. Defaults to 0.5.
        compare_matching_method (str, optional): Matching method used to
            compare the detection results with the ground truth when
            computing 'AP' and 'MR'. Valid methods include 'VOC' and
            None (CALTECH). Defaults to None.
        mr_ref (str): Different parameter selection to calculate MR. Valid
            refs include 'CALTECH_-2' and 'CALTECH_-4'.
            Defaults to 'CALTECH_-2'.
        num_ji_process (int): The number of processes used to evaluate JI.
            Defaults to 10.
    """
    default_prefix: Optional[str] = 'crowd_human'

    def __init__(self,
                 ann_file: str,
                 metric: Union[str, List[str]] = ['AP', 'MR', 'JI'],
                 format_only: bool = False,
                 outfile_prefix: Optional[str] = None,
                 file_client_args: dict = None,
                 backend_args: dict = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 eval_mode: int = 0,
                 iou_thres: float = 0.5,
                 compare_matching_method: Optional[str] = None,
                 mr_ref: str = 'CALTECH_-2',
                 num_ji_process: int = 10) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ann_file = ann_file
        # crowdhuman evaluation metrics
        self.metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['MR', 'AP', 'JI']
        for metric in self.metrics:
            if metric not in allowed_metrics:
                raise KeyError(f"metric should be one of 'MR', 'AP', 'JI', "
                               f'but got {metric}.')

        self.format_only = format_only
        if self.format_only:
            assert outfile_prefix is not None, \
                'outfile_prefix must not be None when format_only is True, ' \
                'otherwise the result files will be saved to a temp ' \
                'directory which will be cleaned up at the end.'
        self.outfile_prefix = outfile_prefix
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )

        assert eval_mode in [0, 1, 2], \
            'Unknown eval mode. eval_mode should be one of 0, 1 or 2.'
        assert compare_matching_method is None or \
               compare_matching_method == 'VOC', \
            "The alternative compare_matching_method is 'VOC'. " \
            'This parameter defaults to CALTECH (None).'
        assert mr_ref == 'CALTECH_-2' or mr_ref == 'CALTECH_-4', \
            "mr_ref should be one of 'CALTECH_-2', 'CALTECH_-4'."
        self.eval_mode = eval_mode
        self.iou_thres = iou_thres
        self.compare_matching_method = compare_matching_method
        self.mr_ref = mr_ref
        self.num_ji_process = num_ji_process

    @staticmethod
    def results2json(results: Sequence[dict], outfile_prefix: str) -> str:
        """Dump the detection results to a json file."""
        result_file_path = f'{outfile_prefix}.json'
        bbox_json_results = []
        for i, result in enumerate(results):
            ann, pred = result
            dump_dict = dict()
            dump_dict['ID'] = ann['ID']
            dump_dict['width'] = ann['width']
            dump_dict['height'] = ann['height']
            dtboxes = []
            bboxes = pred.tolist()
            for _, single_bbox in enumerate(bboxes):
                temp_dict = dict()
                x1, y1, x2, y2, score = single_bbox
                temp_dict['box'] = [x1, y1, x2 - x1, y2 - y1]
                temp_dict['score'] = score
                temp_dict['tag'] = 1
                dtboxes.append(temp_dict)
            dump_dict['dtboxes'] = dtboxes
            bbox_json_results.append(dump_dict)
        dump(bbox_json_results, result_file_path)
        return result_file_path
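
    # Illustrative shape of the dumped json produced above (one dict per
    # image; values shown are placeholders):
    #     {'ID': ..., 'width': ..., 'height': ...,
    #      'dtboxes': [{'box': [x, y, w, h], 'score': s, 'tag': 1}, ...]}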

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
        """
        for data_sample in data_samples:
            ann = dict()
            ann['ID'] = data_sample['img_id']
            ann['width'] = data_sample['ori_shape'][1]
            ann['height'] = data_sample['ori_shape'][0]
            pred_bboxes = data_sample['pred_instances']['bboxes'].cpu().numpy()
            pred_scores = data_sample['pred_instances']['scores'].cpu().numpy()

            pred_bbox_scores = np.hstack(
                [pred_bboxes, pred_scores.reshape((-1, 1))])

            self.results.append((ann, pred_bbox_scores))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            eval_results (Dict[str, float]): The computed metrics.
                The keys are the names of the metrics, and the values
                are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        tmp_dir = None
        if self.outfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            outfile_prefix = osp.join(tmp_dir.name, 'result')
        else:
            outfile_prefix = self.outfile_prefix

        # convert predictions to coco format and dump to json file
        result_file = self.results2json(results, outfile_prefix)
        eval_results = OrderedDict()
        if self.format_only:
            logger.info(f'results are saved in {osp.dirname(outfile_prefix)}')
            return eval_results

        # load evaluation samples
        eval_samples = self.load_eval_samples(result_file)

        if 'AP' in self.metrics or 'MR' in self.metrics:
            score_list = self.compare(eval_samples)
            gt_num = sum([eval_samples[i].gt_num for i in eval_samples])
            ign_num = sum([eval_samples[i].ign_num for i in eval_samples])
            gt_num = gt_num - ign_num
            img_num = len(eval_samples)

        for metric in self.metrics:
            logger.info(f'Evaluating {metric}...')
            if metric == 'AP':
                AP = self.eval_ap(score_list, gt_num, img_num)
                eval_results['mAP'] = float(f'{round(AP, 4)}')
            if metric == 'MR':
                MR = self.eval_mr(score_list, gt_num, img_num)
                eval_results['mMR'] = float(f'{round(MR, 4)}')
            if metric == 'JI':
                JI = self.eval_ji(eval_samples)
                eval_results['JI'] = float(f'{round(JI, 4)}')
        if tmp_dir is not None:
            tmp_dir.cleanup()

        return eval_results

    def load_eval_samples(self, result_file):
        """Load data from annotations file and detection results.

        Args:
            result_file (str): The file path of the saved detection results.

        Returns:
            Dict[Image]: The detection result packaged by Image.
        """
        gt_str = get_text(
            self.ann_file, backend_args=self.backend_args).strip().split('\n')
        gt_records = [json.loads(line) for line in gt_str]

        pred_records = load(result_file, backend_args=self.backend_args)
        eval_samples = dict()
        for gt_record, pred_record in zip(gt_records, pred_records):
            assert gt_record['ID'] == pred_record['ID'], \
                'please set val_dataloader.sampler.shuffle=False and try again'
            eval_samples[pred_record['ID']] = Image(self.eval_mode)
            eval_samples[pred_record['ID']].load(gt_record, 'box', None,
                                                 PERSON_CLASSES, True)
            eval_samples[pred_record['ID']].load(pred_record, 'box', None,
                                                 PERSON_CLASSES, False)
            eval_samples[pred_record['ID']].clip_all_boader()
        return eval_samples

    def compare(self, samples):
        """Match the detection results with the ground_truth.

        Args:
            samples (dict[Image]): The detection result packaged by Image.

        Returns:
            score_list (list[tuple[ndarray, int, str]]): Matching result.
                A list of tuples (dtbox, label, imgID) in descending order
                of dtbox.score.
        """
        score_list = list()
        for id in samples:
            if self.compare_matching_method == 'VOC':
                result = samples[id].compare_voc(self.iou_thres)
            else:
                result = samples[id].compare_caltech(self.iou_thres)
            score_list.extend(result)
        # sort in descending order of dtbox score
        score_list.sort(key=lambda x: x[0][-1], reverse=True)
        return score_list

    @staticmethod
    def eval_ap(score_list, gt_num, img_num):
        """Evaluate by average precision.

        Args:
            score_list (list[tuple[ndarray, int, str]]): Matching result.
                A list of tuples (dtbox, label, imgID) in descending order
                of dtbox.score.
            gt_num (int): The number of gt boxes in the entire dataset.
            img_num (int): The number of images in the entire dataset.

        Returns:
            ap (float): Result of average precision.
        """

        # calculate general ap score
        def _calculate_map(_recall, _precision):
            assert len(_recall) == len(_precision)
            area = 0
            for k in range(1, len(_recall)):
                delta_h = (_precision[k - 1] + _precision[k]) / 2
                delta_w = _recall[k] - _recall[k - 1]
                area += delta_w * delta_h
            return area

        tp, fp = 0.0, 0.0
        rpX, rpY = list(), list()

        fpn = []
        recalln = []
        thr = []
        fppi = []
        for i, item in enumerate(score_list):
            if item[1] == 1:
                tp += 1.0
            elif item[1] == 0:
                fp += 1.0
            fn = gt_num - tp
            recall = tp / (tp + fn)
            precision = tp / (tp + fp)
            rpX.append(recall)
            rpY.append(precision)
            fpn.append(fp)
            recalln.append(tp)
            thr.append(item[0][-1])
            fppi.append(fp / img_num)

        ap = _calculate_map(rpX, rpY)
        return ap
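
    # Worked example of the trapezoidal integration in ``_calculate_map``
    # (values are assumed, for illustration only): with recall
    # [0.0, 0.5, 1.0] and precision [1.0, 0.8, 0.6], the area is
    #     0.5 * (1.0 + 0.8) / 2 + 0.5 * (0.8 + 0.6) / 2 = 0.45 + 0.35 = 0.80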

    def eval_mr(self, score_list, gt_num, img_num):
        """Evaluate by Caltech-style log-average miss rate.

        Args:
            score_list (list[tuple[ndarray, int, str]]): Matching result.
                A list of tuples (dtbox, label, imgID) in descending order
                of dtbox.score.
            gt_num (int): The number of gt boxes in the entire dataset.
            img_num (int): The number of images in the entire dataset.

        Returns:
            mr (float): Result of miss rate.
        """

        # find greater_than
        def _find_gt(lst, target):
            for idx, _item in enumerate(lst):
                if _item >= target:
                    return idx
            return len(lst) - 1

        if self.mr_ref == 'CALTECH_-2':
            # CALTECH_MRREF_2: anchor points (from 10^-2 to 1) as in
            # P.Dollar's paper
            ref = [
                0.0100, 0.0178, 0.03160, 0.0562, 0.1000, 0.1778, 0.3162,
                0.5623, 1.000
            ]
        else:
            # CALTECH_MRREF_4: anchor points (from 10^-4 to 1) as in
            # S.Zhang's paper
            ref = [
                0.0001, 0.0003, 0.00100, 0.0032, 0.0100, 0.0316, 0.1000,
                0.3162, 1.000
            ]

        tp, fp = 0.0, 0.0
        fppiX, fppiY = list(), list()
        for i, item in enumerate(score_list):
            if item[1] == 1:
                tp += 1.0
            elif item[1] == 0:
                fp += 1.0

            fn = gt_num - tp
            recall = tp / (tp + fn)
            missrate = 1.0 - recall
            fppi = fp / img_num
            fppiX.append(fppi)
            fppiY.append(missrate)

        score = list()
        for pos in ref:
            argmin = _find_gt(fppiX, pos)
            if argmin >= 0:
                score.append(fppiY[argmin])
        score = np.array(score)
        mr = np.exp(np.log(score).mean())
        return mr
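
    # Worked example of the log-average miss rate above (assumed values): if
    # the miss rates sampled at the reference FPPI points were
    # [0.4, 0.2, 0.1], then
    #     mr = exp(mean(log([0.4, 0.2, 0.1])))
    #        = (0.4 * 0.2 * 0.1) ** (1 / 3) = 0.2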

    def eval_ji(self, samples):
        """Evaluate by JI using multi_process.

        Args:
            samples (Dict[str, Image]): The detection result packaged by
                Image.

        Returns:
            ji (float): Result of jaccard index.
        """
        import math
        res_line = []
        res_ji = []
        for i in range(10):
            score_thr = 1e-1 * i
            total = len(samples)
            stride = math.ceil(total / self.num_ji_process)
            result_queue = Queue(10000)
            results, procs = [], []
            records = list(samples.items())

            for i in range(self.num_ji_process):
                start = i * stride
                end = np.min([start + stride, total])
                sample_data = dict(records[start:end])
                p = Process(
                    target=self.compute_ji_with_ignore,
                    args=(result_queue, sample_data, score_thr))
                p.start()
                procs.append(p)
            for i in range(total):
                t = result_queue.get()
                results.append(t)
            for p in procs:
                p.join()
            line, mean_ratio = self.gather(results)
            line = 'score_thr:{:.1f}, {}'.format(score_thr, line)
            res_line.append(line)
            res_ji.append(mean_ratio)
        return max(res_ji)

    def compute_ji_with_ignore(self, result_queue, dt_result, score_thr):
        """Compute JI with ignore.

        Args:
            result_queue (Queue): The Queue used to save compute results
                during multi-process evaluation.
            dt_result (dict[Image]): Detection result packaged by Image.
            score_thr (float): The threshold of detection score.

        Returns:
            dict: Compute result, put into ``result_queue``.
        """
        for ID, record in dt_result.items():
            gt_boxes = record.gt_boxes
            dt_boxes = record.dt_boxes
            keep = dt_boxes[:, -1] > score_thr
            dt_boxes = dt_boxes[keep][:, :-1]

            gt_tag = np.array(gt_boxes[:, -1] != -1)
            matches = self.compute_ji_matching(dt_boxes, gt_boxes[gt_tag, :4])
            # get the unmatched_indices
            matched_indices = np.array([j for (j, _) in matches])
            unmatched_indices = list(
                set(np.arange(dt_boxes.shape[0])) - set(matched_indices))
            num_ignore_dt = self.get_ignores(dt_boxes[unmatched_indices],
                                             gt_boxes[~gt_tag, :4])
            matched_indices = np.array([j for (_, j) in matches])
            unmatched_indices = list(
                set(np.arange(gt_boxes[gt_tag].shape[0])) -
                set(matched_indices))
            num_ignore_gt = self.get_ignores(
                gt_boxes[gt_tag][unmatched_indices], gt_boxes[~gt_tag, :4])
            # compute results
            eps = 1e-6
            k = len(matches)
            m = gt_tag.sum() - num_ignore_gt
            n = dt_boxes.shape[0] - num_ignore_dt
            ratio = k / (m + n - k + eps)
            recall = k / (m + eps)
            cover = k / (n + eps)
            noise = 1 - cover
            result_dict = dict(
                ratio=ratio,
                recall=recall,
                cover=cover,
                noise=noise,
                k=k,
                m=m,
                n=n)
            result_queue.put_nowait(result_dict)
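
    # Per-image Jaccard Index computed above: with k matched pairs, m
    # non-ignored gt boxes and n kept detections,
    #     ratio = k / (m + n - k)   (up to the eps term).
    # E.g. (assumed numbers) k=8, m=10, n=9 gives 8 / (10 + 9 - 8) ~= 0.727.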

    @staticmethod
    def gather(results):
        """Integrate test results."""
        assert len(results)
        img_num = 0
        for result in results:
            if result['n'] != 0 or result['m'] != 0:
                img_num += 1
        mean_ratio = np.sum([rb['ratio'] for rb in results]) / img_num
        valids = np.sum([rb['k'] for rb in results])
        total = np.sum([rb['n'] for rb in results])
        gtn = np.sum([rb['m'] for rb in results])
        line = 'mean_ratio:{:.4f}, valids:{}, total:{}, gtn:{}'\
            .format(mean_ratio, valids, total, gtn)
        return line, mean_ratio

    def compute_ji_matching(self, dt_boxes, gt_boxes):
        """Match the annotation box for each detection box.

        Args:
            dt_boxes (ndarray): Detection boxes.
            gt_boxes (ndarray): Ground_truth boxes.

        Returns:
            matches_ (list[tuple[int, int]]): Match result.
        """
        assert dt_boxes.shape[-1] > 3 and gt_boxes.shape[-1] > 3
        if dt_boxes.shape[0] < 1 or gt_boxes.shape[0] < 1:
            return list()

        ious = bbox_overlaps(dt_boxes, gt_boxes, mode='iou')
        input_ = copy.deepcopy(ious)
        input_[input_ < self.iou_thres] = 0
        match_scipy = maximum_bipartite_matching(
            csr_matrix(input_), perm_type='column')
        matches_ = []
        for i in range(len(match_scipy)):
            if match_scipy[i] != -1:
                matches_.append((i, int(match_scipy[i])))
        return matches_
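
    # Minimal sketch of the scipy matching used above, on a toy IoU matrix
    # (assumed values). After zeroing entries below iou_thres=0.5 only the
    # diagonal survives, so the maximum matching is forced:
    #     ious = np.array([[0.9, 0.0], [0.0, 0.7]])
    #     maximum_bipartite_matching(csr_matrix(ious), perm_type='column')
    #     # -> array([0, 1]): detection 0 <-> gt 0, detection 1 <-> gt 1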

    def get_ignores(self, dt_boxes, gt_boxes):
        """Get the number of ignore bboxes."""
        if gt_boxes.size:
            ioas = bbox_overlaps(dt_boxes, gt_boxes, mode='iof')
            ioas = np.max(ioas, axis=1)
            rows = np.where(ioas > self.iou_thres)[0]
            return len(rows)
        else:
            return 0


class Image(object):
    """Data structure for evaluation of CrowdHuman.

    Note:
        This implementation is modified from https://github.com/Purkialo/
        CrowdDet/blob/master/lib/evaluate/APMRToolkits/image.py

    Args:
        mode (int): Select the evaluation mode. Valid modes include
            0 (body box only), 1 (head box only) and 2 (both of them).
            Defaults to 0.
    """

    def __init__(self, mode):
        self.ID = None
        self.width = None
        self.height = None
        self.dt_boxes = None
        self.gt_boxes = None
        self.eval_mode = mode

        self.ign_num = None
        self.gt_num = None
        self.dt_num = None

    def load(self, record, body_key, head_key, class_names, gt_flag):
        """Loading information for evaluation.

        Args:
            record (dict): Label information or test results.
                The format might look something like this:
                {
                    'ID': '273271,c9db000d5146c15',
                    'gtboxes': [
                        {'fbox': [72, 202, 163, 503], 'tag': 'person', ...},
                        {'fbox': [199, 180, 144, 499], 'tag': 'person', ...},
                        ...
                    ]
                }
                or:
                {
                    'ID': '273271,c9db000d5146c15',
                    'width': 800,
                    'height': 1067,
                    'dtboxes': [
                        {
                            'box': [306.22, 205.95, 164.05, 394.04],
                            'score': 0.99,
                            'tag': 1
                        },
                        {
                            'box': [403.60, 178.66, 157.15, 421.33],
                            'score': 0.99,
                            'tag': 1
                        },
                        ...
                    ]
                }
            body_key (str, None): Key of the detection body box.
                Valid when loading detection results and self.eval_mode != 1.
            head_key (str, None): Key of the detection head box.
                Valid when loading detection results and self.eval_mode != 0.
            class_names (list[str]): Class names of the dataset.
                Defaults to ['background', 'person'].
            gt_flag (bool): Indicate whether record is ground truth
                or detection results.
        """
        if 'ID' in record and self.ID is None:
            self.ID = record['ID']
        if 'width' in record and self.width is None:
            self.width = record['width']
        if 'height' in record and self.height is None:
            self.height = record['height']
        if gt_flag:
            self.gt_num = len(record['gtboxes'])
            body_bbox, head_bbox = self.load_gt_boxes(record, 'gtboxes',
                                                      class_names)
            if self.eval_mode == 0:
                self.gt_boxes = body_bbox
                self.ign_num = (body_bbox[:, -1] == -1).sum()
            elif self.eval_mode == 1:
                self.gt_boxes = head_bbox
                self.ign_num = (head_bbox[:, -1] == -1).sum()
            else:
                gt_tag = np.array([
                    body_bbox[i, -1] != -1 and head_bbox[i, -1] != -1
                    for i in range(len(body_bbox))
                ])
                self.ign_num = (gt_tag == 0).sum()
                self.gt_boxes = np.hstack(
                    (body_bbox[:, :-1], head_bbox[:, :-1],
                     gt_tag.reshape(-1, 1)))

        if not gt_flag:
            self.dt_num = len(record['dtboxes'])
            if self.eval_mode == 0:
                self.dt_boxes = self.load_det_boxes(record, 'dtboxes',
                                                    body_key, 'score')
            elif self.eval_mode == 1:
                self.dt_boxes = self.load_det_boxes(record, 'dtboxes',
                                                    head_key, 'score')
            else:
                body_dtboxes = self.load_det_boxes(record, 'dtboxes', body_key,
                                                   'score')
                head_dtboxes = self.load_det_boxes(record, 'dtboxes', head_key,
                                                   'score')
                self.dt_boxes = np.hstack((body_dtboxes, head_dtboxes))

    @staticmethod
    def load_gt_boxes(dict_input, key_name, class_names):
        """Load ground_truth and transform [x, y, w, h] to [x1, y1, x2, y2]."""
        assert key_name in dict_input
        if len(dict_input[key_name]) < 1:
            # return empty body and head boxes so callers can still unpack
            return np.empty([0, 5]), np.empty([0, 5])
        head_bbox = []
        body_bbox = []
        for rb in dict_input[key_name]:
            if rb['tag'] in class_names:
                body_tag = class_names.index(rb['tag'])
                head_tag = copy.deepcopy(body_tag)
            else:
                body_tag = -1
                head_tag = -1
            if 'extra' in rb:
                if 'ignore' in rb['extra']:
                    if rb['extra']['ignore'] != 0:
                        body_tag = -1
                        head_tag = -1
            if 'head_attr' in rb:
                if 'ignore' in rb['head_attr']:
                    if rb['head_attr']['ignore'] != 0:
                        head_tag = -1
            head_bbox.append(np.hstack((rb['hbox'], head_tag)))
            body_bbox.append(np.hstack((rb['fbox'], body_tag)))
        head_bbox = np.array(head_bbox)
        head_bbox[:, 2:4] += head_bbox[:, :2]
        body_bbox = np.array(body_bbox)
        body_bbox[:, 2:4] += body_bbox[:, :2]
        return body_bbox, head_bbox
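
    # Coordinate convention used above: an annotation box stored as
    # [x, y, w, h], e.g. [10, 20, 30, 40], becomes the corner format
    # [x1, y1, x2, y2] = [10, 20, 40, 60] after ``bbox[2:4] += bbox[:2]``.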

    @staticmethod
    def load_det_boxes(dict_input, key_name, key_box, key_score, key_tag=None):
        """Load detection boxes."""
        assert key_name in dict_input
        if len(dict_input[key_name]) < 1:
            return np.empty([0, 5])
        else:
            assert key_box in dict_input[key_name][0]
            if key_score:
                assert key_score in dict_input[key_name][0]
            if key_tag:
                assert key_tag in dict_input[key_name][0]
            if key_score:
                if key_tag:
                    bboxes = np.vstack([
                        np.hstack((rb[key_box], rb[key_score], rb[key_tag]))
                        for rb in dict_input[key_name]
                    ])
                else:
                    bboxes = np.vstack([
                        np.hstack((rb[key_box], rb[key_score]))
                        for rb in dict_input[key_name]
                    ])
            else:
                if key_tag:
                    bboxes = np.vstack([
                        np.hstack((rb[key_box], rb[key_tag]))
                        for rb in dict_input[key_name]
                    ])
                else:
                    bboxes = np.vstack(
                        [rb[key_box] for rb in dict_input[key_name]])
            bboxes[:, 2:4] += bboxes[:, :2]
            return bboxes

    def clip_all_boader(self):
        """Make sure boxes are within the image range."""

        def _clip_boundary(boxes, height, width):
            assert boxes.shape[-1] >= 4
            boxes[:, 0] = np.minimum(np.maximum(boxes[:, 0], 0), width - 1)
            boxes[:, 1] = np.minimum(np.maximum(boxes[:, 1], 0), height - 1)
            boxes[:, 2] = np.maximum(np.minimum(boxes[:, 2], width), 0)
            boxes[:, 3] = np.maximum(np.minimum(boxes[:, 3], height), 0)
            return boxes

        assert self.dt_boxes.shape[-1] >= 4
        assert self.gt_boxes.shape[-1] >= 4
        assert self.width is not None and self.height is not None
        if self.eval_mode == 2:
            self.dt_boxes[:, :4] = _clip_boundary(self.dt_boxes[:, :4],
                                                  self.height, self.width)
            self.gt_boxes[:, :4] = _clip_boundary(self.gt_boxes[:, :4],
                                                  self.height, self.width)
            self.dt_boxes[:, 4:8] = _clip_boundary(self.dt_boxes[:, 4:8],
                                                   self.height, self.width)
            self.gt_boxes[:, 4:8] = _clip_boundary(self.gt_boxes[:, 4:8],
                                                   self.height, self.width)
        else:
            self.dt_boxes = _clip_boundary(self.dt_boxes, self.height,
                                           self.width)
            self.gt_boxes = _clip_boundary(self.gt_boxes, self.height,
                                           self.width)

    def compare_voc(self, thres):
        """Match the detection results with the ground_truth by VOC.

        Args:
            thres (float): IOU threshold.

        Returns:
            score_list (list[tuple[ndarray, int, str]]): Matching result.
                A list of tuples (dtbox, label, imgID) in descending order
                of dtbox.score.
        """
        if self.dt_boxes is None:
            return list()
        dtboxes = self.dt_boxes
        gtboxes = self.gt_boxes if self.gt_boxes is not None else list()
        dtboxes.sort(key=lambda x: x.score, reverse=True)
        gtboxes.sort(key=lambda x: x.ign)

        score_list = list()
        for i, dt in enumerate(dtboxes):
            maxpos = -1
            maxiou = thres

            for j, gt in enumerate(gtboxes):
                overlap = dt.iou(gt)
                if overlap > maxiou:
                    maxiou = overlap
                    maxpos = j

            if maxpos >= 0:
                if gtboxes[maxpos].ign == 0:
                    gtboxes[maxpos].matched = 1
                    dtboxes[i].matched = 1
                    score_list.append((dt, self.ID))
                else:
                    dtboxes[i].matched = -1
            else:
                dtboxes[i].matched = 0
                score_list.append((dt, self.ID))
        return score_list

    def compare_caltech(self, thres):
        """Match the detection results with the ground_truth by Caltech
        matching strategy.

        Args:
            thres (float): IOU threshold.

        Returns:
            score_list (list[tuple[ndarray, int, str]]): Matching result.
                A list of tuples (dtbox, label, imgID) in descending order
                of dtbox.score.
        """
        if self.dt_boxes is None or self.gt_boxes is None:
            return list()

        dtboxes = self.dt_boxes if self.dt_boxes is not None else list()
        gtboxes = self.gt_boxes if self.gt_boxes is not None else list()
        dt_matched = np.zeros(dtboxes.shape[0])
        gt_matched = np.zeros(gtboxes.shape[0])

        dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True))
        gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True))
        if len(dtboxes):
            overlap_iou = bbox_overlaps(dtboxes, gtboxes, mode='iou')
            overlap_ioa = bbox_overlaps(dtboxes, gtboxes, mode='iof')
        else:
            return list()

        score_list = list()
        for i, dt in enumerate(dtboxes):
            maxpos = -1
            maxiou = thres
            for j, gt in enumerate(gtboxes):
                if gt_matched[j] == 1:
                    continue

                if gt[-1] > 0:
                    overlap = overlap_iou[i][j]
                    if overlap > maxiou:
                        maxiou = overlap
                        maxpos = j
                else:
                    if maxpos >= 0:
                        break
                    else:
                        overlap = overlap_ioa[i][j]
                        if overlap > thres:
                            maxiou = overlap
                            maxpos = j

            if maxpos >= 0:
                if gtboxes[maxpos, -1] > 0:
                    gt_matched[maxpos] = 1
                    dt_matched[i] = 1
                    score_list.append((dt, 1, self.ID))
                else:
                    dt_matched[i] = -1
            else:
                dt_matched[i] = 0
                score_list.append((dt, 0, self.ID))
        return score_list
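

if __name__ == '__main__':
    # Hedged, self-contained smoke test (not part of the original module):
    # it illustrates the record formats ``Image`` expects, using invented
    # toy data, and runs only when this file is executed directly.
    toy_gt = {
        'ID': 'toy', 'width': 100, 'height': 100,
        'gtboxes': [{
            'fbox': [10, 10, 30, 60],
            'hbox': [15, 12, 10, 10],
            'tag': 'person'
        }]
    }
    toy_dt = {
        'ID': 'toy', 'width': 100, 'height': 100,
        'dtboxes': [{'box': [12, 11, 28, 58], 'score': 0.9, 'tag': 1}]
    }
    sample = Image(mode=0)
    sample.load(toy_gt, 'box', None, PERSON_CLASSES, True)
    sample.load(toy_dt, 'box', None, PERSON_CLASSES, False)
    sample.clip_all_boader()
    # Each entry is (dtbox, matched_label, image_ID); here the single
    # detection overlaps its ground truth with IoU > 0.5, so the label is 1.
    print(sample.compare_caltech(0.5))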