1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import unittest
- from mmdet.datasets import VOCDataset
class TestVOCDataset(unittest.TestCase):
    """Check that ``VOCDataset`` loads the VOC2007 and VOC2012 test fixtures.

    Both tests build a dataset rooted at ``tests/data/VOCdevkit/`` with an
    empty pipeline. They then check the dataset length, the length of the
    raw data list, the instance count of the first item, and the category
    ids of item 0.
    """

    def _assert_single_image_dataset(self, ann_file, sub_data_root,
                                     filter_cfg, num_instances, cat_ids):
        """Build a ``VOCDataset`` from the fixtures and verify its contents.

        Args:
            ann_file (str): Annotation split file, relative to the data root.
            sub_data_root (str): Value passed as
                ``data_prefix['sub_data_root']``.
            filter_cfg (dict): Filter configuration, forwarded unchanged.
            num_instances (int): Expected ``len(data_list[0]['instances'])``.
            cat_ids (list[int]): Expected return value of
                ``dataset.get_cat_ids(0)``.
        """
        dataset = VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file=ann_file,
            data_prefix=dict(sub_data_root=sub_data_root),
            filter_cfg=filter_cfg,
            pipeline=[])
        # full_init() is called explicitly before any inspection below.
        dataset.full_init()
        self.assertEqual(len(dataset), 1)

        # The raw data list is queried separately from the dataset object
        # and is expected to hold the same single item.
        loaded = dataset.load_data_list()
        self.assertEqual(len(loaded), 1)
        self.assertEqual(len(loaded[0]['instances']), num_instances)
        self.assertEqual(dataset.get_cat_ids(0), cat_ids)

    def test_voc2007_init(self):
        """VOC2007 fixture: one image with two instances."""
        self._assert_single_image_dataset(
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            sub_data_root='VOC2007/',
            filter_cfg=dict(
                filter_empty_gt=True, min_size=32, bbox_min_size=32),
            num_instances=2,
            cat_ids=[11, 14])

    def test_voc2012_init(self):
        """VOC2012 fixture: one image with a single instance."""
        self._assert_single_image_dataset(
            ann_file='VOC2012/ImageSets/Main/trainval.txt',
            sub_data_root='VOC2012/',
            filter_cfg=dict(filter_empty_gt=True, min_size=32),
            num_instances=1,
            cat_ids=[18])
|