# class_aware_sampler.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. from typing import Dict, Iterator, Optional, Union
  4. import numpy as np
  5. import torch
  6. from mmengine.dataset import BaseDataset
  7. from mmengine.dist import get_dist_info, sync_random_seed
  8. from torch.utils.data import Sampler
  9. from mmdet.registry import DATA_SAMPLERS
  10. @DATA_SAMPLERS.register_module()
  11. class ClassAwareSampler(Sampler):
  12. r"""Sampler that restricts data loading to the label of the dataset.
  13. A class-aware sampling strategy to effectively tackle the
  14. non-uniform class distribution. The length of the training data is
  15. consistent with source data. Simple improvements based on `Relay
  16. Backpropagation for Effective Learning of Deep Convolutional
  17. Neural Networks <https://arxiv.org/abs/1512.05830>`_
  18. The implementation logic is referred to
  19. https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
  20. Args:
  21. dataset: Dataset used for sampling.
  22. seed (int, optional): random seed used to shuffle the sampler.
  23. This number should be identical across all
  24. processes in the distributed group. Defaults to None.
  25. num_sample_class (int): The number of samples taken from each
  26. per-label list. Defaults to 1.
  27. """
  28. def __init__(self,
  29. dataset: BaseDataset,
  30. seed: Optional[int] = None,
  31. num_sample_class: int = 1) -> None:
  32. rank, world_size = get_dist_info()
  33. self.rank = rank
  34. self.world_size = world_size
  35. self.dataset = dataset
  36. self.epoch = 0
  37. # Must be the same across all workers. If None, will use a
  38. # random seed shared among workers
  39. # (require synchronization among all workers)
  40. if seed is None:
  41. seed = sync_random_seed()
  42. self.seed = seed
  43. # The number of samples taken from each per-label list
  44. assert num_sample_class > 0 and isinstance(num_sample_class, int)
  45. self.num_sample_class = num_sample_class
  46. # Get per-label image list from dataset
  47. self.cat_dict = self.get_cat2imgs()
  48. self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
  49. self.total_size = self.num_samples * self.world_size
  50. # get number of images containing each category
  51. self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
  52. # filter labels without images
  53. self.valid_cat_inds = [
  54. i for i, length in enumerate(self.num_cat_imgs) if length != 0
  55. ]
  56. self.num_classes = len(self.valid_cat_inds)
  57. def get_cat2imgs(self) -> Dict[int, list]:
  58. """Get a dict with class as key and img_ids as values.
  59. Returns:
  60. dict[int, list]: A dict of per-label image list,
  61. the item of the dict indicates a label index,
  62. corresponds to the image index that contains the label.
  63. """
  64. classes = self.dataset.metainfo.get('classes', None)
  65. if classes is None:
  66. raise ValueError('dataset metainfo must contain `classes`')
  67. # sort the label index
  68. cat2imgs = {i: [] for i in range(len(classes))}
  69. for i in range(len(self.dataset)):
  70. cat_ids = set(self.dataset.get_cat_ids(i))
  71. for cat in cat_ids:
  72. cat2imgs[cat].append(i)
  73. return cat2imgs
  74. def __iter__(self) -> Iterator[int]:
  75. # deterministically shuffle based on epoch
  76. g = torch.Generator()
  77. g.manual_seed(self.epoch + self.seed)
  78. # initialize label list
  79. label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
  80. # initialize each per-label image list
  81. data_iter_dict = dict()
  82. for i in self.valid_cat_inds:
  83. data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
  84. def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
  85. """Traverse the categories and extract `num_sample_cls` image
  86. indexes of the corresponding categories one by one."""
  87. id_indices = []
  88. for _ in range(len(cls_list)):
  89. cls_idx = next(cls_list)
  90. for _ in range(num_sample_cls):
  91. id = next(data_dict[cls_idx])
  92. id_indices.append(id)
  93. return id_indices
  94. # deterministically shuffle based on epoch
  95. num_bins = int(
  96. math.ceil(self.total_size * 1.0 / self.num_classes /
  97. self.num_sample_class))
  98. indices = []
  99. for i in range(num_bins):
  100. indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
  101. self.num_sample_class)
  102. # fix extra samples to make it evenly divisible
  103. if len(indices) >= self.total_size:
  104. indices = indices[:self.total_size]
  105. else:
  106. indices += indices[:(self.total_size - len(indices))]
  107. assert len(indices) == self.total_size
  108. # subsample
  109. offset = self.num_samples * self.rank
  110. indices = indices[offset:offset + self.num_samples]
  111. assert len(indices) == self.num_samples
  112. return iter(indices)
  113. def __len__(self) -> int:
  114. """The number of samples in this rank."""
  115. return self.num_samples
  116. def set_epoch(self, epoch: int) -> None:
  117. """Sets the epoch for this sampler.
  118. When :attr:`shuffle=True`, this ensures all replicas use a different
  119. random ordering for each epoch. Otherwise, the next iteration of this
  120. sampler will yield the same ordering.
  121. Args:
  122. epoch (int): Epoch number.
  123. """
  124. self.epoch = epoch
  125. class RandomCycleIter:
  126. """Shuffle the list and do it again after the list have traversed.
  127. The implementation logic is referred to
  128. https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
  129. Example:
  130. >>> label_list = [0, 1, 2, 4, 5]
  131. >>> g = torch.Generator()
  132. >>> g.manual_seed(0)
  133. >>> label_iter_list = RandomCycleIter(label_list, generator=g)
  134. >>> index = next(label_iter_list)
  135. Args:
  136. data (list or ndarray): The data that needs to be shuffled.
  137. generator: An torch.Generator object, which is used in setting the seed
  138. for generating random numbers.
  139. """ # noqa: W605
  140. def __init__(self,
  141. data: Union[list, np.ndarray],
  142. generator: torch.Generator = None) -> None:
  143. self.data = data
  144. self.length = len(data)
  145. self.index = torch.randperm(self.length, generator=generator).numpy()
  146. self.i = 0
  147. self.generator = generator
  148. def __iter__(self) -> Iterator:
  149. return self
  150. def __len__(self) -> int:
  151. return len(self.data)
  152. def __next__(self):
  153. if self.i == self.length:
  154. self.index = torch.randperm(
  155. self.length, generator=self.generator).numpy()
  156. self.i = 0
  157. idx = self.data[self.index[self.i]]
  158. self.i += 1
  159. return idx