test_dump_det_results.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import tempfile
  4. from unittest import TestCase
  5. import torch
  6. from mmengine.fileio import load
  7. from torch import Tensor
  8. from mmdet.evaluation import DumpDetResults
  9. from mmdet.structures.mask import encode_mask_results
  10. class TestDumpResults(TestCase):
  11. def test_init(self):
  12. with self.assertRaisesRegex(ValueError,
  13. 'The output file must be a pkl file.'):
  14. DumpDetResults(out_file_path='./results.json')
  15. def test_process(self):
  16. metric = DumpDetResults(out_file_path='./results.pkl')
  17. data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
  18. metric.process(None, data_samples)
  19. self.assertEqual(len(metric.results), 1)
  20. self.assertEqual(metric.results[0]['data'][0].device,
  21. torch.device('cpu'))
  22. metric = DumpDetResults(out_file_path='./results.pkl')
  23. masks = torch.zeros(10, 10, 4)
  24. data_samples = [
  25. dict(pred_instances=dict(masks=masks), gt_instances=[])
  26. ]
  27. metric.process(None, data_samples)
  28. self.assertEqual(len(metric.results), 1)
  29. self.assertEqual(metric.results[0]['pred_instances']['masks'],
  30. encode_mask_results(masks.numpy()))
  31. self.assertNotIn('gt_instances', metric.results[0])
  32. def test_compute_metrics(self):
  33. temp_dir = tempfile.TemporaryDirectory()
  34. path = osp.join(temp_dir.name, 'results.pkl')
  35. metric = DumpDetResults(out_file_path=path)
  36. data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
  37. metric.process(None, data_samples)
  38. metric.compute_metrics(metric.results)
  39. self.assertTrue(osp.isfile(path))
  40. results = load(path)
  41. self.assertEqual(len(results), 1)
  42. self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
  43. temp_dir.cleanup()