test_multi_source_sampler.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import bisect
  3. from unittest import TestCase
  4. from unittest.mock import patch
  5. import numpy as np
  6. from torch.utils.data import ConcatDataset, Dataset
  7. from mmdet.datasets.samplers import GroupMultiSourceSampler, MultiSourceSampler
  8. class DummyDataset(Dataset):
  9. def __init__(self, length, flag):
  10. self.length = length
  11. self.flag = flag
  12. self.shapes = np.random.random((length, 2))
  13. def __len__(self):
  14. return self.length
  15. def __getitem__(self, idx):
  16. return self.shapes[idx]
  17. def get_data_info(self, idx):
  18. return dict(
  19. width=self.shapes[idx][0],
  20. height=self.shapes[idx][1],
  21. flag=self.flag)
  22. class DummyConcatDataset(ConcatDataset):
  23. def _get_ori_dataset_idx(self, idx):
  24. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  25. sample_idx = idx if dataset_idx == 0 else idx - self.cumulative_sizes[
  26. dataset_idx - 1]
  27. return dataset_idx, sample_idx
  28. def get_data_info(self, idx: int):
  29. dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
  30. return self.datasets[dataset_idx].get_data_info(sample_idx)
  31. class TestMultiSourceSampler(TestCase):
  32. @patch('mmengine.dist.get_dist_info', return_value=(7, 8))
  33. def setUp(self, mock):
  34. self.length_a = 100
  35. self.dataset_a = DummyDataset(self.length_a, flag='a')
  36. self.length_b = 1000
  37. self.dataset_b = DummyDataset(self.length_b, flag='b')
  38. self.dataset = DummyConcatDataset([self.dataset_a, self.dataset_b])
  39. def test_multi_source_sampler(self):
  40. # test dataset is not ConcatDataset
  41. with self.assertRaises(AssertionError):
  42. MultiSourceSampler(
  43. self.dataset_a, batch_size=5, source_ratio=[1, 4])
  44. # test invalid batch_size
  45. with self.assertRaises(AssertionError):
  46. MultiSourceSampler(
  47. self.dataset_a, batch_size=-5, source_ratio=[1, 4])
  48. # test source_ratio longer then dataset
  49. with self.assertRaises(AssertionError):
  50. MultiSourceSampler(
  51. self.dataset, batch_size=5, source_ratio=[1, 2, 4])
  52. sampler = MultiSourceSampler(
  53. self.dataset, batch_size=5, source_ratio=[1, 4])
  54. sampler = iter(sampler)
  55. flags = []
  56. for i in range(100):
  57. idx = next(sampler)
  58. flags.append(self.dataset.get_data_info(idx)['flag'])
  59. flags_gt = ['a', 'b', 'b', 'b', 'b'] * 20
  60. self.assertEqual(flags, flags_gt)
  61. class TestGroupMultiSourceSampler(TestCase):
  62. @patch('mmengine.dist.get_dist_info', return_value=(7, 8))
  63. def setUp(self, mock):
  64. self.length_a = 100
  65. self.dataset_a = DummyDataset(self.length_a, flag='a')
  66. self.length_b = 1000
  67. self.dataset_b = DummyDataset(self.length_b, flag='b')
  68. self.dataset = DummyConcatDataset([self.dataset_a, self.dataset_b])
  69. def test_group_multi_source_sampler(self):
  70. sampler = GroupMultiSourceSampler(
  71. self.dataset, batch_size=5, source_ratio=[1, 4])
  72. sampler = iter(sampler)
  73. flags = []
  74. groups = []
  75. for i in range(100):
  76. idx = next(sampler)
  77. data_info = self.dataset.get_data_info(idx)
  78. flags.append(data_info['flag'])
  79. group = 0 if data_info['width'] < data_info['height'] else 1
  80. groups.append(group)
  81. flags_gt = ['a', 'b', 'b', 'b', 'b'] * 20
  82. self.assertEqual(flags, flags_gt)
  83. groups = set(
  84. [sum(x) for x in (groups[k:k + 5] for k in range(0, 100, 5))])
  85. groups_gt = set([0, 5])
  86. self.assertEqual(groups, groups_gt)