test_openimages_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import numpy as np
import torch

from mmdet.datasets import OpenImagesDataset
from mmdet.evaluation import OpenImagesMetric
from mmdet.utils import register_all_modules
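
# The paths used in test_eval below assume the small Open Images fixtures
# that live under tests/data/OpenImages/ in the mmdetection repository.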


class TestOpenImagesMetric(unittest.TestCase):

    def _create_dummy_results(self):
        # Four fake detections of class 0 in (x1, y1, x2, y2) pixel
        # coordinates, with monotonically decreasing confidence scores.
        bboxes = np.array([[23.2172, 31.7541, 987.3413, 357.8443],
                           [100, 120, 130, 150], [150, 160, 190, 200],
                           [250, 260, 350, 360]])
        scores = np.array([1.0, 0.98, 0.96, 0.95])
        labels = np.array([0, 0, 0, 0])
        return dict(
            bboxes=torch.from_numpy(bboxes),
            scores=torch.from_numpy(scores),
            labels=torch.from_numpy(labels))

    def test_init(self):
        # test invalid iou_thrs
        with self.assertRaises(AssertionError):
            OpenImagesMetric(iou_thrs={'a', 0.5}, ioa_thrs={'b', 0.5})

        # iou_thrs and ioa_thrs must have equal lengths
        with self.assertRaises(AssertionError):
            OpenImagesMetric(iou_thrs=[0.5, 0.75], ioa_thrs=[0.5])

        metric = OpenImagesMetric(iou_thrs=0.6)
        self.assertEqual(metric.iou_thrs, [0.6])

    def test_eval(self):
        register_all_modules()
        dataset = OpenImagesDataset(
            data_root='tests/data/OpenImages/',
            ann_file='annotations/oidv6-train-annotations-bbox.csv',
            data_prefix=dict(img='OpenImages/train/'),
            label_file='annotations/class-descriptions-boxable.csv',
            hierarchy_file='annotations/bbox_labels_600_hierarchy.json',
            meta_file='annotations/image-metas.pkl',
            pipeline=[
                dict(type='LoadAnnotations', with_bbox=True),
                dict(
                    type='PackDetInputs',
                    meta_keys=('img_id', 'img_path', 'instances'))
            ])
        dataset.full_init()
        data_sample = dataset[0]['data_samples'].to_dict()
        data_sample['pred_instances'] = self._create_dummy_results()

        metric = OpenImagesMetric()
        metric.dataset_meta = dataset.metainfo
        metric.process({}, [data_sample])
        results = metric.evaluate(size=len(dataset))
        targets = {'openimages/AP50': 1.0, 'openimages/mAP': 1.0}
        self.assertDictEqual(results, targets)
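
        # Result keys carry an 'openimages/' prefix: one 'AP{100*thr}'
        # entry per IoU threshold (AP50 above, AP10/AP50 below) plus
        # 'mAP', the mean over all thresholds.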

        # test multi-threshold
        metric = OpenImagesMetric(iou_thrs=[0.1, 0.5], ioa_thrs=[0.1, 0.5])
        metric.dataset_meta = dataset.metainfo
        metric.process({}, [data_sample])
        results = metric.evaluate(size=len(dataset))
        targets = {
            'openimages/AP10': 1.0,
            'openimages/AP50': 1.0,
            'openimages/mAP': 1.0
        }
        self.assertDictEqual(results, targets)
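

# Not part of the upstream test module, but a standard unittest entry point
# so the file can also be run directly with `python`; pytest does not need it.
if __name__ == '__main__':
    unittest.main()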