test_pascal_voc.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from mmdet.datasets import VOCDataset
  4. class TestVOCDataset(unittest.TestCase):
  5. def test_voc2007_init(self):
  6. dataset = VOCDataset(
  7. data_root='tests/data/VOCdevkit/',
  8. ann_file='VOC2007/ImageSets/Main/trainval.txt',
  9. data_prefix=dict(sub_data_root='VOC2007/'),
  10. filter_cfg=dict(
  11. filter_empty_gt=True, min_size=32, bbox_min_size=32),
  12. pipeline=[])
  13. dataset.full_init()
  14. self.assertEqual(len(dataset), 1)
  15. data_list = dataset.load_data_list()
  16. self.assertEqual(len(data_list), 1)
  17. self.assertEqual(len(data_list[0]['instances']), 2)
  18. self.assertEqual(dataset.get_cat_ids(0), [11, 14])
  19. def test_voc2012_init(self):
  20. dataset = VOCDataset(
  21. data_root='tests/data/VOCdevkit/',
  22. ann_file='VOC2012/ImageSets/Main/trainval.txt',
  23. data_prefix=dict(sub_data_root='VOC2012/'),
  24. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  25. pipeline=[])
  26. dataset.full_init()
  27. self.assertEqual(len(dataset), 1)
  28. data_list = dataset.load_data_list()
  29. self.assertEqual(len(data_list), 1)
  30. self.assertEqual(len(data_list[0]['instances']), 1)
  31. self.assertEqual(dataset.get_cat_ids(0), [18])