test_batch_sampler.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. from unittest.mock import patch
  4. import numpy as np
  5. from mmengine.dataset import DefaultSampler
  6. from torch.utils.data import Dataset
  7. from mmdet.datasets.samplers import AspectRatioBatchSampler
  8. class DummyDataset(Dataset):
  9. def __init__(self, length):
  10. self.length = length
  11. self.shapes = np.random.random((length, 2))
  12. def __len__(self):
  13. return self.length
  14. def __getitem__(self, idx):
  15. return self.shapes[idx]
  16. def get_data_info(self, idx):
  17. return dict(width=self.shapes[idx][0], height=self.shapes[idx][1])
  18. class TestAspectRatioBatchSampler(TestCase):
  19. @patch('mmengine.dist.get_dist_info', return_value=(0, 1))
  20. def setUp(self, mock):
  21. self.length = 100
  22. self.dataset = DummyDataset(self.length)
  23. self.sampler = DefaultSampler(self.dataset, shuffle=False)
  24. def test_invalid_inputs(self):
  25. with self.assertRaisesRegex(
  26. ValueError, 'batch_size should be a positive integer value'):
  27. AspectRatioBatchSampler(self.sampler, batch_size=-1)
  28. with self.assertRaisesRegex(
  29. TypeError, 'sampler should be an instance of ``Sampler``'):
  30. AspectRatioBatchSampler(None, batch_size=1)
  31. def test_divisible_batch(self):
  32. batch_size = 5
  33. batch_sampler = AspectRatioBatchSampler(
  34. self.sampler, batch_size=batch_size, drop_last=True)
  35. self.assertEqual(len(batch_sampler), self.length // batch_size)
  36. for batch_idxs in batch_sampler:
  37. self.assertEqual(len(batch_idxs), batch_size)
  38. batch = [self.dataset[idx] for idx in batch_idxs]
  39. flag = batch[0][0] < batch[0][1]
  40. for i in range(1, batch_size):
  41. self.assertEqual(batch[i][0] < batch[i][1], flag)
  42. def test_indivisible_batch(self):
  43. batch_size = 7
  44. batch_sampler = AspectRatioBatchSampler(
  45. self.sampler, batch_size=batch_size, drop_last=False)
  46. all_batch_idxs = list(batch_sampler)
  47. self.assertEqual(
  48. len(batch_sampler), (self.length + batch_size - 1) // batch_size)
  49. self.assertEqual(
  50. len(all_batch_idxs), (self.length + batch_size - 1) // batch_size)
  51. batch_sampler = AspectRatioBatchSampler(
  52. self.sampler, batch_size=batch_size, drop_last=True)
  53. all_batch_idxs = list(batch_sampler)
  54. self.assertEqual(len(batch_sampler), self.length // batch_size)
  55. self.assertEqual(len(all_batch_idxs), self.length // batch_size)
  56. # the last batch may not have the same aspect ratio
  57. for batch_idxs in all_batch_idxs[:-1]:
  58. self.assertEqual(len(batch_idxs), batch_size)
  59. batch = [self.dataset[idx] for idx in batch_idxs]
  60. flag = batch[0][0] < batch[0][1]
  61. for i in range(1, batch_size):
  62. self.assertEqual(batch[i][0] < batch[i][1], flag)