# test_cityscapes_metric.py (3.4 KB)
import os
import os.path as osp
import tempfile
import unittest

import numpy as np
import torch
from PIL import Image

from mmdet.evaluation import CityScapesMetric

# ``cityscapesscripts`` is optional: fall back to ``None`` when it is missing
# so that code below can test for its availability.
try:
    import cityscapesscripts
except ImportError:
    cityscapesscripts = None
  13. class TestCityScapesMetric(unittest.TestCase):
  14. def setUp(self):
  15. self.tmp_dir = tempfile.TemporaryDirectory()
  16. def tearDown(self):
  17. self.tmp_dir.cleanup()
  18. @unittest.skipIf(cityscapesscripts is None,
  19. 'cityscapesscripts is not installed.')
  20. def test_init(self):
  21. # test with outfile_prefix = None
  22. with self.assertRaises(AssertionError):
  23. CityScapesMetric(outfile_prefix=None)
  24. @unittest.skipIf(cityscapesscripts is None,
  25. 'cityscapesscripts is not installed.')
  26. def test_evaluate(self):
  27. dummy_mask1 = np.zeros((1, 20, 20), dtype=np.uint8)
  28. dummy_mask1[:, :10, :10] = 1
  29. dummy_mask2 = np.zeros((1, 20, 20), dtype=np.uint8)
  30. dummy_mask2[:, :10, :10] = 1
  31. self.outfile_prefix = osp.join(self.tmp_dir.name, 'test')
  32. self.seg_prefix = osp.join(self.tmp_dir.name, 'cityscapes/gtFine/val')
  33. city = 'lindau'
  34. sequenceNb = '000000'
  35. frameNb = '000019'
  36. img_name1 = f'{city}_{sequenceNb}_{frameNb}_gtFine_instanceIds.png'
  37. img_path1 = osp.join(self.seg_prefix, city, img_name1)
  38. frameNb = '000020'
  39. img_name2 = f'{city}_{sequenceNb}_{frameNb}_gtFine_instanceIds.png'
  40. img_path2 = osp.join(self.seg_prefix, city, img_name2)
  41. os.makedirs(osp.join(self.seg_prefix, city))
  42. masks1 = np.zeros((20, 20), dtype=np.int32)
  43. masks1[:10, :10] = 24 * 1000
  44. Image.fromarray(masks1).save(img_path1)
  45. masks2 = np.zeros((20, 20), dtype=np.int32)
  46. masks2[:10, :10] = 24 * 1000 + 1
  47. Image.fromarray(masks2).save(img_path2)
  48. data_samples = [{
  49. 'img_path': img_path1,
  50. 'pred_instances': {
  51. 'scores': torch.from_numpy(np.array([1.0])),
  52. 'labels': torch.from_numpy(np.array([0])),
  53. 'masks': torch.from_numpy(dummy_mask1)
  54. }
  55. }, {
  56. 'img_path': img_path2,
  57. 'pred_instances': {
  58. 'scores': torch.from_numpy(np.array([0.98])),
  59. 'labels': torch.from_numpy(np.array([1])),
  60. 'masks': torch.from_numpy(dummy_mask2)
  61. }
  62. }]
  63. target = {'cityscapes/mAP': 0.5, 'cityscapes/AP@50': 0.5}
  64. metric = CityScapesMetric(
  65. seg_prefix=self.seg_prefix,
  66. format_only=False,
  67. outfile_prefix=self.outfile_prefix)
  68. metric.dataset_meta = dict(
  69. classes=('person', 'rider', 'car', 'truck', 'bus', 'train',
  70. 'motorcycle', 'bicycle'))
  71. metric.process({}, data_samples)
  72. results = metric.evaluate(size=2)
  73. self.assertDictEqual(results, target)
  74. del metric
  75. self.assertTrue(not osp.exists('{self.outfile_prefix}.results'))
  76. # test format_only
  77. metric = CityScapesMetric(
  78. seg_prefix=self.seg_prefix,
  79. format_only=True,
  80. outfile_prefix=self.outfile_prefix)
  81. metric.dataset_meta = dict(
  82. classes=('person', 'rider', 'car', 'truck', 'bus', 'train',
  83. 'motorcycle', 'bicycle'))
  84. metric.process({}, data_samples)
  85. results = metric.evaluate(size=2)
  86. self.assertDictEqual(results, dict())