augment_wrappers.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Union
  3. import numpy as np
  4. from mmcv.transforms import RandomChoice
  5. from mmcv.transforms.utils import cache_randomness
  6. from mmengine.config import ConfigDict
  7. from mmdet.registry import TRANSFORMS
# AutoAugment uses reinforcement learning to search for widely useful
# data augmentation strategies; here we provide AUTOAUG_POLICIES_V0.
# In AUTOAUG_POLICIES_V0, each tuple is an augmentation operation of the
# form (operation, probability, magnitude). Each element of the list is a
# policy whose operations are applied sequentially to the image.
#
# RandAugment defines a data augmentation search space, RANDAUG_SPACE,
# sampling 1~3 data augmentations each time and setting the magnitude of
# each data augmentation randomly; the sampled augmentations are then
# applied sequentially to the image.
  19. _MAX_LEVEL = 10
  20. AUTOAUG_POLICIES_V0 = [
  21. [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
  22. [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
  23. [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
  24. [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
  25. [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
  26. [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
  27. [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
  28. [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
  29. [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
  30. [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
  31. [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
  32. [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
  33. [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
  34. [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
  35. [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
  36. [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
  37. [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
  38. [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
  39. [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
  40. [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
  41. [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
  42. [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
  43. [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
  44. [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
  45. [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
  46. ]
  47. def policies_v0():
  48. """Autoaugment policies that was used in AutoAugment Paper."""
  49. policies = list()
  50. for policy_args in AUTOAUG_POLICIES_V0:
  51. policy = list()
  52. for args in policy_args:
  53. policy.append(dict(type=args[0], prob=args[1], level=args[2]))
  54. policies.append(policy)
  55. return policies
  56. RANDAUG_SPACE = [[dict(type='AutoContrast')], [dict(type='Equalize')],
  57. [dict(type='Invert')], [dict(type='Rotate')],
  58. [dict(type='Posterize')], [dict(type='Solarize')],
  59. [dict(type='SolarizeAdd')], [dict(type='Color')],
  60. [dict(type='Contrast')], [dict(type='Brightness')],
  61. [dict(type='Sharpness')], [dict(type='ShearX')],
  62. [dict(type='ShearY')], [dict(type='TranslateX')],
  63. [dict(type='TranslateY')]]
  64. def level_to_mag(level: Optional[int], min_mag: float,
  65. max_mag: float) -> float:
  66. """Map from level to magnitude."""
  67. if level is None:
  68. return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1)
  69. else:
  70. return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1)
  71. @TRANSFORMS.register_module()
  72. class AutoAugment(RandomChoice):
  73. """Auto augmentation.
  74. This data augmentation is proposed in `AutoAugment: Learning
  75. Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
  76. and in `Learning Data Augmentation Strategies for Object Detection
  77. <https://arxiv.org/pdf/1906.11172>`_.
  78. Required Keys:
  79. - img
  80. - gt_bboxes (BaseBoxes[torch.float32]) (optional)
  81. - gt_bboxes_labels (np.int64) (optional)
  82. - gt_masks (BitmapMasks | PolygonMasks) (optional)
  83. - gt_ignore_flags (bool) (optional)
  84. - gt_seg_map (np.uint8) (optional)
  85. Modified Keys:
  86. - img
  87. - img_shape
  88. - gt_bboxes
  89. - gt_bboxes_labels
  90. - gt_masks
  91. - gt_ignore_flags
  92. - gt_seg_map
  93. Added Keys:
  94. - homography_matrix
  95. Args:
  96. policies (List[List[Union[dict, ConfigDict]]]):
  97. The policies of auto augmentation.Each policy in ``policies``
  98. is a specific augmentation policy, and is composed by several
  99. augmentations. When AutoAugment is called, a random policy in
  100. ``policies`` will be selected to augment images.
  101. Defaults to policy_v0().
  102. prob (list[float], optional): The probabilities associated
  103. with each policy. The length should be equal to the policy
  104. number and the sum should be 1. If not given, a uniform
  105. distribution will be assumed. Defaults to None.
  106. Examples:
  107. >>> policies = [
  108. >>> [
  109. >>> dict(type='Sharpness', prob=0.0, level=8),
  110. >>> dict(type='ShearX', prob=0.4, level=0,)
  111. >>> ],
  112. >>> [
  113. >>> dict(type='Rotate', prob=0.6, level=10),
  114. >>> dict(type='Color', prob=1.0, level=6)
  115. >>> ]
  116. >>> ]
  117. >>> augmentation = AutoAugment(policies)
  118. >>> img = np.ones(100, 100, 3)
  119. >>> gt_bboxes = np.ones(10, 4)
  120. >>> results = dict(img=img, gt_bboxes=gt_bboxes)
  121. >>> results = augmentation(results)
  122. """
  123. def __init__(self,
  124. policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
  125. prob: Optional[List[float]] = None) -> None:
  126. assert isinstance(policies, list) and len(policies) > 0, \
  127. 'Policies must be a non-empty list.'
  128. for policy in policies:
  129. assert isinstance(policy, list) and len(policy) > 0, \
  130. 'Each policy in policies must be a non-empty list.'
  131. for augment in policy:
  132. assert isinstance(augment, dict) and 'type' in augment, \
  133. 'Each specific augmentation must be a dict with key' \
  134. ' "type".'
  135. super().__init__(transforms=policies, prob=prob)
  136. self.policies = policies
  137. def __repr__(self) -> str:
  138. return f'{self.__class__.__name__}(policies={self.policies}, ' \
  139. f'prob={self.prob})'
  140. @TRANSFORMS.register_module()
  141. class RandAugment(RandomChoice):
  142. """Rand augmentation.
  143. This data augmentation is proposed in `RandAugment:
  144. Practical automated data augmentation with a reduced
  145. search space <https://arxiv.org/abs/1909.13719>`_.
  146. Required Keys:
  147. - img
  148. - gt_bboxes (BaseBoxes[torch.float32]) (optional)
  149. - gt_bboxes_labels (np.int64) (optional)
  150. - gt_masks (BitmapMasks | PolygonMasks) (optional)
  151. - gt_ignore_flags (bool) (optional)
  152. - gt_seg_map (np.uint8) (optional)
  153. Modified Keys:
  154. - img
  155. - img_shape
  156. - gt_bboxes
  157. - gt_bboxes_labels
  158. - gt_masks
  159. - gt_ignore_flags
  160. - gt_seg_map
  161. Added Keys:
  162. - homography_matrix
  163. Args:
  164. aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation space
  165. of rand augmentation. Each augmentation transform in ``aug_space``
  166. is a specific transform, and is composed by several augmentations.
  167. When RandAugment is called, a random transform in ``aug_space``
  168. will be selected to augment images. Defaults to aug_space.
  169. aug_num (int): Number of augmentation to apply equentially.
  170. Defaults to 2.
  171. prob (list[float], optional): The probabilities associated with
  172. each augmentation. The length should be equal to the
  173. augmentation space and the sum should be 1. If not given,
  174. a uniform distribution will be assumed. Defaults to None.
  175. Examples:
  176. >>> aug_space = [
  177. >>> dict(type='Sharpness'),
  178. >>> dict(type='ShearX'),
  179. >>> dict(type='Color'),
  180. >>> ],
  181. >>> augmentation = RandAugment(aug_space)
  182. >>> img = np.ones(100, 100, 3)
  183. >>> gt_bboxes = np.ones(10, 4)
  184. >>> results = dict(img=img, gt_bboxes=gt_bboxes)
  185. >>> results = augmentation(results)
  186. """
  187. def __init__(self,
  188. aug_space: List[Union[dict, ConfigDict]] = RANDAUG_SPACE,
  189. aug_num: int = 2,
  190. prob: Optional[List[float]] = None) -> None:
  191. assert isinstance(aug_space, list) and len(aug_space) > 0, \
  192. 'Augmentation space must be a non-empty list.'
  193. for aug in aug_space:
  194. assert isinstance(aug, list) and len(aug) == 1, \
  195. 'Each augmentation in aug_space must be a list.'
  196. for transform in aug:
  197. assert isinstance(transform, dict) and 'type' in transform, \
  198. 'Each specific transform must be a dict with key' \
  199. ' "type".'
  200. super().__init__(transforms=aug_space, prob=prob)
  201. self.aug_space = aug_space
  202. self.aug_num = aug_num
  203. @cache_randomness
  204. def random_pipeline_index(self):
  205. indices = np.arange(len(self.transforms))
  206. return np.random.choice(
  207. indices, self.aug_num, p=self.prob, replace=False)
  208. def transform(self, results: dict) -> dict:
  209. """Transform function to use RandAugment.
  210. Args:
  211. results (dict): Result dict from loading pipeline.
  212. Returns:
  213. dict: Result dict with RandAugment.
  214. """
  215. for idx in self.random_pipeline_index():
  216. results = self.transforms[idx](results)
  217. return results
  218. def __repr__(self) -> str:
  219. return f'{self.__class__.__name__}(' \
  220. f'aug_space={self.aug_space}, '\
  221. f'aug_num={self.aug_num}, ' \
  222. f'prob={self.prob})'