1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- import tempfile
- from unittest import TestCase
- import torch
- from mmengine.fileio import load
- from torch import Tensor
- from mmdet.evaluation import DumpDetResults
- from mmdet.structures.mask import encode_mask_results
class TestDumpResults(TestCase):
    """Unit tests for the ``DumpDetResults`` metric."""

    def test_init(self):
        """A non-``.pkl`` output path is rejected at construction time."""
        with self.assertRaisesRegex(ValueError,
                                    'The output file must be a pkl file.'):
            DumpDetResults(out_file_path='./results.json')

    def test_process(self):
        """``process`` buffers exactly one result entry per data sample.

        Checks that buffered tensors are on CPU, that predicted masks are
        stored in the form returned by ``encode_mask_results`` and that
        ``gt_instances`` is not kept in the buffered result.
        """
        # Plain tensor payload: the buffered copy must live on the CPU.
        metric = DumpDetResults(out_file_path='./results.pkl')
        data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
        metric.process(None, data_samples)
        self.assertEqual(len(metric.results), 1)
        self.assertEqual(metric.results[0]['data'][0].device,
                         torch.device('cpu'))

        # Mask payload: masks are encoded and ground truth is dropped.
        metric = DumpDetResults(out_file_path='./results.pkl')
        # NOTE(review): presumably (num_instances, H, W) binary masks —
        # confirm against encode_mask_results' expected layout.
        masks = torch.zeros(10, 10, 4)
        data_samples = [
            dict(pred_instances=dict(masks=masks), gt_instances=[])
        ]
        metric.process(None, data_samples)
        self.assertEqual(len(metric.results), 1)
        self.assertEqual(metric.results[0]['pred_instances']['masks'],
                         encode_mask_results(masks.numpy()))
        self.assertNotIn('gt_instances', metric.results[0])

    def test_compute_metrics(self):
        """``compute_metrics`` writes the buffered results to a pkl file."""
        # Use the context-manager form so the temporary directory is removed
        # even when an assertion below fails; a trailing explicit cleanup()
        # call would be skipped on failure and leak the directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = osp.join(tmp_dir, 'results.pkl')
            metric = DumpDetResults(out_file_path=path)
            data_samples = [
                dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))
            ]
            metric.process(None, data_samples)
            metric.compute_metrics(metric.results)
            self.assertTrue(osp.isfile(path))

            # Round-trip: the dumped file holds the single buffered entry,
            # with its tensors still on the CPU.
            results = load(path)
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0]['data'][0].device,
                             torch.device('cpu'))
|