# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from mmengine.dataset import DefaultSampler
from torch.utils.data import Dataset

from mmdet.datasets.samplers import AspectRatioBatchSampler


class DummyDataset(Dataset):
    """Tiny dataset of random ``(width, height)`` pairs.

    Args:
        length (int): Number of samples.
    """

    def __init__(self, length):
        self.length = length
        # One row per sample: column 0 is width, column 1 is height.
        self.shapes = np.random.random((length, 2))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.shapes[idx]

    def get_data_info(self, idx):
        """Return the ``width``/``height`` info for sample ``idx``."""
        return dict(width=self.shapes[idx][0], height=self.shapes[idx][1])


class TestAspectRatioBatchSampler(TestCase):
    """Tests for :class:`AspectRatioBatchSampler` on a ``DummyDataset``."""

    # ``patch`` must replace the name where it is looked up. DefaultSampler
    # resolves ``get_dist_info`` from its own module
    # (``mmengine.dataset.sampler``). Patching ``mmengine.dist.get_dist_info``
    # would leave that already-imported reference untouched.
    @patch('mmengine.dataset.sampler.get_dist_info', return_value=(0, 1))
    def setUp(self, mock):
        self.length = 100
        self.dataset = DummyDataset(self.length)
        self.sampler = DefaultSampler(self.dataset, shuffle=False)

    def _assert_same_aspect_ratio(self, batch_idxs, batch_size):
        """Assert a batch is full and all samples share one orientation.

        Orientation is ``width < height``. This matches column 0 / column 1
        of ``DummyDataset.shapes``.
        """
        self.assertEqual(len(batch_idxs), batch_size)
        batch = [self.dataset[idx] for idx in batch_idxs]
        flag = batch[0][0] < batch[0][1]
        for sample in batch[1:]:
            self.assertEqual(sample[0] < sample[1], flag)

    def test_invalid_inputs(self):
        """Bad ``batch_size`` or a non-Sampler argument must be rejected."""
        with self.assertRaisesRegex(
                ValueError, 'batch_size should be a positive integer value'):
            AspectRatioBatchSampler(self.sampler, batch_size=-1)
        with self.assertRaisesRegex(
                TypeError, 'sampler should be an instance of ``Sampler``'):
            AspectRatioBatchSampler(None, batch_size=1)

    def test_divisible_batch(self):
        """Dataset length divisible by ``batch_size`` with ``drop_last``."""
        batch_size = 5
        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=True)
        self.assertEqual(len(batch_sampler), self.length // batch_size)
        for batch_idxs in batch_sampler:
            self._assert_same_aspect_ratio(batch_idxs, batch_size)

    def test_indivisible_batch(self):
        """Dataset length not divisible by ``batch_size``.

        Covers both ``drop_last=False`` and ``drop_last=True``.
        """
        batch_size = 7
        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=False)
        all_batch_idxs = list(batch_sampler)
        # Without drop_last the trailing partial batch is kept (ceil div).
        self.assertEqual(
            len(batch_sampler), (self.length + batch_size - 1) // batch_size)
        self.assertEqual(
            len(all_batch_idxs), (self.length + batch_size - 1) // batch_size)

        batch_sampler = AspectRatioBatchSampler(
            self.sampler, batch_size=batch_size, drop_last=True)
        all_batch_idxs = list(batch_sampler)
        self.assertEqual(len(batch_sampler), self.length // batch_size)
        self.assertEqual(len(all_batch_idxs), self.length // batch_size)
        # The last batch may not have the same aspect ratio (presumably it is
        # built from leftovers of different aspect-ratio groups), so skip it.
        for batch_idxs in all_batch_idxs[:-1]:
            self._assert_same_aspect_ratio(batch_idxs, batch_size)