# dataset_wrappers.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import collections
  3. import copy
  4. from typing import Sequence, Union
  5. from mmengine.dataset import BaseDataset, force_full_init
  6. from mmdet.registry import DATASETS, TRANSFORMS
  7. @DATASETS.register_module()
  8. class MultiImageMixDataset:
  9. """A wrapper of multiple images mixed dataset.
  10. Suitable for training on multiple images mixed data augmentation like
  11. mosaic and mixup. For the augmentation pipeline of mixed image data,
  12. the `get_indexes` method needs to be provided to obtain the image
  13. indexes, and you can set `skip_flags` to change the pipeline running
  14. process. At the same time, we provide the `dynamic_scale` parameter
  15. to dynamically change the output image size.
  16. Args:
  17. dataset (:obj:`CustomDataset`): The dataset to be mixed.
  18. pipeline (Sequence[dict]): Sequence of transform object or
  19. config dict to be composed.
  20. dynamic_scale (tuple[int], optional): The image scale can be changed
  21. dynamically. Default to None. It is deprecated.
  22. skip_type_keys (list[str], optional): Sequence of type string to
  23. be skip pipeline. Default to None.
  24. max_refetch (int): The maximum number of retry iterations for getting
  25. valid results from the pipeline. If the number of iterations is
  26. greater than `max_refetch`, but results is still None, then the
  27. iteration is terminated and raise the error. Default: 15.
  28. """
  29. def __init__(self,
  30. dataset: Union[BaseDataset, dict],
  31. pipeline: Sequence[str],
  32. skip_type_keys: Union[Sequence[str], None] = None,
  33. max_refetch: int = 15,
  34. lazy_init: bool = False) -> None:
  35. assert isinstance(pipeline, collections.abc.Sequence)
  36. if skip_type_keys is not None:
  37. assert all([
  38. isinstance(skip_type_key, str)
  39. for skip_type_key in skip_type_keys
  40. ])
  41. self._skip_type_keys = skip_type_keys
  42. self.pipeline = []
  43. self.pipeline_types = []
  44. for transform in pipeline:
  45. if isinstance(transform, dict):
  46. self.pipeline_types.append(transform['type'])
  47. transform = TRANSFORMS.build(transform)
  48. self.pipeline.append(transform)
  49. else:
  50. raise TypeError('pipeline must be a dict')
  51. self.dataset: BaseDataset
  52. if isinstance(dataset, dict):
  53. self.dataset = DATASETS.build(dataset)
  54. elif isinstance(dataset, BaseDataset):
  55. self.dataset = dataset
  56. else:
  57. raise TypeError(
  58. 'elements in datasets sequence should be config or '
  59. f'`BaseDataset` instance, but got {type(dataset)}')
  60. self._metainfo = self.dataset.metainfo
  61. if hasattr(self.dataset, 'flag'):
  62. self.flag = self.dataset.flag
  63. self.num_samples = len(self.dataset)
  64. self.max_refetch = max_refetch
  65. self._fully_initialized = False
  66. if not lazy_init:
  67. self.full_init()
  68. @property
  69. def metainfo(self) -> dict:
  70. """Get the meta information of the multi-image-mixed dataset.
  71. Returns:
  72. dict: The meta information of multi-image-mixed dataset.
  73. """
  74. return copy.deepcopy(self._metainfo)
  75. def full_init(self):
  76. """Loop to ``full_init`` each dataset."""
  77. if self._fully_initialized:
  78. return
  79. self.dataset.full_init()
  80. self._ori_len = len(self.dataset)
  81. self._fully_initialized = True
  82. @force_full_init
  83. def get_data_info(self, idx: int) -> dict:
  84. """Get annotation by index.
  85. Args:
  86. idx (int): Global index of ``ConcatDataset``.
  87. Returns:
  88. dict: The idx-th annotation of the datasets.
  89. """
  90. return self.dataset.get_data_info(idx)
  91. @force_full_init
  92. def __len__(self):
  93. return self.num_samples
  94. def __getitem__(self, idx):
  95. results = copy.deepcopy(self.dataset[idx])
  96. for (transform, transform_type) in zip(self.pipeline,
  97. self.pipeline_types):
  98. if self._skip_type_keys is not None and \
  99. transform_type in self._skip_type_keys:
  100. continue
  101. if hasattr(transform, 'get_indexes'):
  102. for i in range(self.max_refetch):
  103. # Make sure the results passed the loading pipeline
  104. # of the original dataset is not None.
  105. indexes = transform.get_indexes(self.dataset)
  106. if not isinstance(indexes, collections.abc.Sequence):
  107. indexes = [indexes]
  108. mix_results = [
  109. copy.deepcopy(self.dataset[index]) for index in indexes
  110. ]
  111. if None not in mix_results:
  112. results['mix_results'] = mix_results
  113. break
  114. else:
  115. raise RuntimeError(
  116. 'The loading pipeline of the original dataset'
  117. ' always return None. Please check the correctness '
  118. 'of the dataset and its pipeline.')
  119. for i in range(self.max_refetch):
  120. # To confirm the results passed the training pipeline
  121. # of the wrapper is not None.
  122. updated_results = transform(copy.deepcopy(results))
  123. if updated_results is not None:
  124. results = updated_results
  125. break
  126. else:
  127. raise RuntimeError(
  128. 'The training pipeline of the dataset wrapper'
  129. ' always return None.Please check the correctness '
  130. 'of the dataset and its pipeline.')
  131. if 'mix_results' in results:
  132. results.pop('mix_results')
  133. return results
  134. def update_skip_type_keys(self, skip_type_keys):
  135. """Update skip_type_keys. It is called by an external hook.
  136. Args:
  137. skip_type_keys (list[str], optional): Sequence of type
  138. string to be skip pipeline.
  139. """
  140. assert all([
  141. isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
  142. ])
  143. self._skip_type_keys = skip_type_keys