# test_coco_metric.py

import os.path as osp
import tempfile
from unittest import TestCase

import numpy as np
import pycocotools.mask as mask_util
import torch
from mmengine.fileio import dump

from mmdet.evaluation import CocoMetric


class TestCocoMetric(TestCase):

    def _create_dummy_coco_json(self, json_name):
        dummy_mask = np.zeros((10, 10), order='F', dtype=np.uint8)
        dummy_mask[:5, :5] = 1
        rle_mask = mask_util.encode(dummy_mask)
        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
        image = {
            'id': 0,
            'width': 640,
            'height': 640,
            'file_name': 'fake_name.jpg',
        }

        annotation_1 = {
            'id': 1,
            'image_id': 0,
            'category_id': 0,
            'area': 400,
            'bbox': [50, 60, 20, 20],
            'iscrowd': 0,
            'segmentation': rle_mask,
        }

        annotation_2 = {
            'id': 2,
            'image_id': 0,
            'category_id': 0,
            'area': 900,
            'bbox': [100, 120, 30, 30],
            'iscrowd': 0,
            'segmentation': rle_mask,
        }

        annotation_3 = {
            'id': 3,
            'image_id': 0,
            'category_id': 1,
            'area': 1600,
            'bbox': [150, 160, 40, 40],
            'iscrowd': 0,
            'segmentation': rle_mask,
        }

        annotation_4 = {
            'id': 4,
            'image_id': 0,
            'category_id': 0,
            'area': 10000,
            'bbox': [250, 260, 100, 100],
            'iscrowd': 0,
            'segmentation': rle_mask,
        }

        categories = [
            {
                'id': 0,
                'name': 'car',
                'supercategory': 'car',
            },
            {
                'id': 1,
                'name': 'bicycle',
                'supercategory': 'bicycle',
            },
        ]

        fake_json = {
            'images': [image],
            'annotations':
            [annotation_1, annotation_2, annotation_3, annotation_4],
            'categories': categories
        }

        dump(fake_json, json_name)

    def _create_dummy_results(self):
        bboxes = np.array([[50, 60, 70, 80], [100, 120, 130, 150],
                           [150, 160, 190, 200], [250, 260, 350, 360]])
        scores = np.array([1.0, 0.98, 0.96, 0.95])
        labels = np.array([0, 0, 1, 0])
        dummy_mask = np.zeros((4, 10, 10), dtype=np.uint8)
        dummy_mask[:, :5, :5] = 1
        return dict(
            bboxes=torch.from_numpy(bboxes),
            scores=torch.from_numpy(scores),
            labels=torch.from_numpy(labels),
            masks=torch.from_numpy(dummy_mask))
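
    # Note: the predicted boxes in _create_dummy_results are the xyxy versions
    # of the xywh ground-truth boxes written by _create_dummy_coco_json (e.g.
    # [50, 60, 70, 80] in xyxy equals [50, 60, 20, 20] in xywh), and the
    # predicted masks match the ground-truth masks, so the evaluation tests
    # below expect perfect (1.0) metrics for the dummy image.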

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.tmp_dir.cleanup()

    def test_init(self):
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        with self.assertRaisesRegex(KeyError, 'metric should be one of'):
            CocoMetric(ann_file=fake_json_file, metric='unknown')

    def test_evaluate(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        dummy_pred = self._create_dummy_results()

        # test single coco dataset evaluation
        coco_metric = CocoMetric(
            ann_file=fake_json_file,
            classwise=False,
            outfile_prefix=f'{self.tmp_dir.name}/test')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/bbox_mAP': 1.0,
            'coco/bbox_mAP_50': 1.0,
            'coco/bbox_mAP_75': 1.0,
            'coco/bbox_mAP_s': 1.0,
            'coco/bbox_mAP_m': 1.0,
            'coco/bbox_mAP_l': 1.0,
        }
        self.assertDictEqual(eval_results, target)
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.bbox.json')))

        # test box and segm coco dataset evaluation
        coco_metric = CocoMetric(
            ann_file=fake_json_file,
            metric=['bbox', 'segm'],
            classwise=False,
            outfile_prefix=f'{self.tmp_dir.name}/test')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/bbox_mAP': 1.0,
            'coco/bbox_mAP_50': 1.0,
            'coco/bbox_mAP_75': 1.0,
            'coco/bbox_mAP_s': 1.0,
            'coco/bbox_mAP_m': 1.0,
            'coco/bbox_mAP_l': 1.0,
            'coco/segm_mAP': 1.0,
            'coco/segm_mAP_50': 1.0,
            'coco/segm_mAP_75': 1.0,
            'coco/segm_mAP_s': 1.0,
            'coco/segm_mAP_m': 1.0,
            'coco/segm_mAP_l': 1.0,
        }
        self.assertDictEqual(eval_results, target)
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.bbox.json')))
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.segm.json')))

        # test invalid custom metric_items
        with self.assertRaisesRegex(KeyError,
                                    'metric item "invalid" is not supported'):
            coco_metric = CocoMetric(
                ann_file=fake_json_file, metric_items=['invalid'])
            coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
            coco_metric.process({}, [
                dict(
                    pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))
            ])
            coco_metric.evaluate(size=1)

        # test custom metric_items
        coco_metric = CocoMetric(
            ann_file=fake_json_file, metric_items=['mAP_m'])
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/bbox_mAP_m': 1.0,
        }
        self.assertDictEqual(eval_results, target)

    def test_classwise_evaluate(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        dummy_pred = self._create_dummy_results()

        # test classwise coco dataset evaluation
        coco_metric = CocoMetric(
            ann_file=fake_json_file, metric='bbox', classwise=True)
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/bbox_mAP': 1.0,
            'coco/bbox_mAP_50': 1.0,
            'coco/bbox_mAP_75': 1.0,
            'coco/bbox_mAP_s': 1.0,
            'coco/bbox_mAP_m': 1.0,
            'coco/bbox_mAP_l': 1.0,
            'coco/car_precision': 1.0,
            'coco/bicycle_precision': 1.0,
        }
        self.assertDictEqual(eval_results, target)

    def test_manually_set_iou_thrs(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)

        # manually set IoU thresholds and check they are stored as given
        coco_metric = CocoMetric(
            ann_file=fake_json_file, metric='bbox', iou_thrs=[0.3, 0.6])
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        self.assertEqual(coco_metric.iou_thrs, [0.3, 0.6])

    def test_fast_eval_recall(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        dummy_pred = self._create_dummy_results()

        # test default proposal nums
        coco_metric = CocoMetric(
            ann_file=fake_json_file, metric='proposal_fast')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {'coco/AR@100': 1.0, 'coco/AR@300': 1.0, 'coco/AR@1000': 1.0}
        self.assertDictEqual(eval_results, target)

        # test manually set proposal nums
        coco_metric = CocoMetric(
            ann_file=fake_json_file,
            metric='proposal_fast',
            proposal_nums=(2, 4))
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        # with only the top-2 proposals kept, at most 2 of the 4 ground-truth
        # boxes can be recalled, hence AR@2 == 0.5
        target = {'coco/AR@2': 0.5, 'coco/AR@4': 1.0}
        self.assertDictEqual(eval_results, target)

    def test_evaluate_proposal(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        dummy_pred = self._create_dummy_results()

        coco_metric = CocoMetric(ann_file=fake_json_file, metric='proposal')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/AR@100': 1.0,
            'coco/AR@300': 1.0,
            'coco/AR@1000': 1.0,
            'coco/AR_s@1000': 1.0,
            'coco/AR_m@1000': 1.0,
            'coco/AR_l@1000': 1.0,
        }
        self.assertDictEqual(eval_results, target)

    def test_empty_results(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        coco_metric = CocoMetric(ann_file=fake_json_file, metric='bbox')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        bboxes = np.zeros((0, 4))
        labels = np.array([])
        scores = np.array([])
        dummy_mask = np.zeros((0, 10, 10), dtype=np.uint8)
        empty_pred = dict(
            bboxes=torch.from_numpy(bboxes),
            scores=torch.from_numpy(scores),
            labels=torch.from_numpy(labels),
            masks=torch.from_numpy(dummy_mask))
        coco_metric.process(
            {},
            [dict(pred_instances=empty_pred, img_id=0, ori_shape=(640, 640))])
        # the IndexError raised by the coco api on empty results should be
        # caught, so evaluate() must not crash
        coco_metric.evaluate(size=1)

    def test_evaluate_without_json(self):
        # with ann_file=None, CocoMetric builds the ground truth from the
        # per-image 'instances' passed to process() and dumps the converted
        # annotations as <outfile_prefix>.gt.json (checked at the end)
        dummy_pred = self._create_dummy_results()

        dummy_mask = np.zeros((10, 10), order='F', dtype=np.uint8)
        dummy_mask[:5, :5] = 1
        rle_mask = mask_util.encode(dummy_mask)
        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
        instances = [{
            'bbox_label': 0,
            'bbox': [50, 60, 70, 80],
            'ignore_flag': 0,
            'mask': rle_mask,
        }, {
            'bbox_label': 0,
            'bbox': [100, 120, 130, 150],
            'ignore_flag': 0,
            'mask': rle_mask,
        }, {
            'bbox_label': 1,
            'bbox': [150, 160, 190, 200],
            'ignore_flag': 0,
            'mask': rle_mask,
        }, {
            'bbox_label': 0,
            'bbox': [250, 260, 350, 360],
            'ignore_flag': 0,
            'mask': rle_mask,
        }]
        coco_metric = CocoMetric(
            ann_file=None,
            metric=['bbox', 'segm'],
            classwise=False,
            outfile_prefix=f'{self.tmp_dir.name}/test')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process({}, [
            dict(
                pred_instances=dummy_pred,
                img_id=0,
                ori_shape=(640, 640),
                instances=instances)
        ])
        eval_results = coco_metric.evaluate(size=1)
        target = {
            'coco/bbox_mAP': 1.0,
            'coco/bbox_mAP_50': 1.0,
            'coco/bbox_mAP_75': 1.0,
            'coco/bbox_mAP_s': 1.0,
            'coco/bbox_mAP_m': 1.0,
            'coco/bbox_mAP_l': 1.0,
            'coco/segm_mAP': 1.0,
            'coco/segm_mAP_50': 1.0,
            'coco/segm_mAP_75': 1.0,
            'coco/segm_mAP_s': 1.0,
            'coco/segm_mAP_m': 1.0,
            'coco/segm_mAP_l': 1.0,
        }
        self.assertDictEqual(eval_results, target)
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.bbox.json')))
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.segm.json')))
        self.assertTrue(
            osp.isfile(osp.join(self.tmp_dir.name, 'test.gt.json')))
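
    # With format_only=True, CocoMetric only dumps the formatted predictions
    # (so an outfile_prefix is required) and skips evaluation, which is why
    # evaluate() is expected to return an empty dict; test_format_only below
    # checks both behaviours.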

    def test_format_only(self):
        # create dummy data
        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
        self._create_dummy_coco_json(fake_json_file)
        dummy_pred = self._create_dummy_results()

        with self.assertRaises(AssertionError):
            CocoMetric(
                ann_file=fake_json_file,
                classwise=False,
                format_only=True,
                outfile_prefix=None)

        coco_metric = CocoMetric(
            ann_file=fake_json_file,
            metric='bbox',
            classwise=False,
            format_only=True,
            outfile_prefix=f'{self.tmp_dir.name}/test')
        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
            {},
            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
        eval_results = coco_metric.evaluate(size=1)
        self.assertDictEqual(eval_results, dict())
        self.assertTrue(osp.exists(f'{self.tmp_dir.name}/test.bbox.json'))
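
# A minimal way to run this test module directly, assuming pytest and the
# dependencies imported above (mmdet, mmengine, pycocotools, torch) are
# installed in the environment:
#
#     pytest test_coco_metric.py -v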