test_augment_wrappers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import unittest
  4. from mmdet.datasets.transforms import (AutoAugment, AutoContrast, Brightness,
  5. Color, Contrast, Equalize, Invert,
  6. Posterize, RandAugment, Rotate,
  7. Sharpness, ShearX, ShearY, Solarize,
  8. SolarizeAdd, TranslateX, TranslateY)
  9. from mmdet.utils import register_all_modules
  10. from .utils import check_result_same, construct_toy_data
  11. register_all_modules()
  12. class TestAutoAugment(unittest.TestCase):
  13. def setUp(self):
  14. """Setup the model and optimizer which are used in every test method.
  15. TestCase calls functions in this order: setUp() -> testMethod() ->
  16. tearDown() -> cleanUp()
  17. """
  18. self.check_keys = ('img', 'gt_bboxes', 'gt_bboxes_labels', 'gt_masks',
  19. 'gt_ignore_flags', 'gt_seg_map',
  20. 'homography_matrix')
  21. self.results_mask = construct_toy_data(poly2mask=True)
  22. self.img_fill_val = (104, 116, 124)
  23. self.seg_ignore_label = 255
  24. def test_autoaugment(self):
  25. # test AutoAugment equipped with Shear
  26. policies = [[
  27. dict(type='ShearX', prob=1.0, level=3, reversal_prob=0.0),
  28. dict(type='ShearY', prob=1.0, level=7, reversal_prob=1.0)
  29. ]]
  30. transform_auto = AutoAugment(policies=policies)
  31. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  32. transform_shearx = ShearX(prob=1.0, level=3, reversal_prob=0.0)
  33. transform_sheary = ShearY(prob=1.0, level=7, reversal_prob=1.0)
  34. results_sheared = transform_sheary(
  35. transform_shearx(copy.deepcopy(self.results_mask)))
  36. check_result_same(results_sheared, results_auto, self.check_keys)
  37. # test AutoAugment equipped with Rotate
  38. policies = [[
  39. dict(type='Rotate', prob=1.0, level=10, reversal_prob=0.0),
  40. ]]
  41. transform_auto = AutoAugment(policies=policies)
  42. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  43. transform_rotate = Rotate(prob=1.0, level=10, reversal_prob=0.0)
  44. results_rotated = transform_rotate(copy.deepcopy(self.results_mask))
  45. check_result_same(results_rotated, results_auto, self.check_keys)
  46. # test AutoAugment equipped with Translate
  47. policies = [[
  48. dict(
  49. type='TranslateX',
  50. prob=1.0,
  51. level=10,
  52. max_mag=1.0,
  53. reversal_prob=0.0),
  54. dict(
  55. type='TranslateY',
  56. prob=1.0,
  57. level=10,
  58. max_mag=1.0,
  59. reversal_prob=1.0)
  60. ]]
  61. transform_auto = AutoAugment(policies=policies)
  62. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  63. transform_translatex = TranslateX(
  64. prob=1.0, level=10, max_mag=1.0, reversal_prob=0.0)
  65. transform_translatey = TranslateY(
  66. prob=1.0, level=10, max_mag=1.0, reversal_prob=1.0)
  67. results_translated = transform_translatey(
  68. transform_translatex(copy.deepcopy(self.results_mask)))
  69. check_result_same(results_translated, results_auto, self.check_keys)
  70. # test AutoAugment equipped with Brightness
  71. policies = [[
  72. dict(type='Brightness', prob=1.0, level=3),
  73. ]]
  74. transform_auto = AutoAugment(policies=policies)
  75. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  76. transform_brightness = Brightness(prob=1.0, level=3)
  77. results_brightness = transform_brightness(
  78. copy.deepcopy(self.results_mask))
  79. check_result_same(results_brightness, results_auto, self.check_keys)
  80. # test AutoAugment equipped with Color
  81. policies = [[
  82. dict(type='Color', prob=1.0, level=3),
  83. ]]
  84. transform_auto = AutoAugment(policies=policies)
  85. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  86. transform_color = Color(prob=1.0, level=3)
  87. results_colored = transform_color(copy.deepcopy(self.results_mask))
  88. check_result_same(results_colored, results_auto, self.check_keys)
  89. # test AutoAugment equipped with Contrast
  90. policies = [[
  91. dict(type='Contrast', prob=1.0, level=3),
  92. ]]
  93. transform_auto = AutoAugment(policies=policies)
  94. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  95. transform_contrast = Contrast(prob=1.0, level=3)
  96. results_contrasted = transform_contrast(
  97. copy.deepcopy(self.results_mask))
  98. check_result_same(results_contrasted, results_auto, self.check_keys)
  99. # test AutoAugment equipped with Sharpness
  100. policies = [[
  101. dict(type='Sharpness', prob=1.0, level=3),
  102. ]]
  103. transform_auto = AutoAugment(policies=policies)
  104. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  105. transform_sharpness = Sharpness(prob=1.0, level=3)
  106. results_sharpness = transform_sharpness(
  107. copy.deepcopy(self.results_mask))
  108. check_result_same(results_sharpness, results_auto, self.check_keys)
  109. # test AutoAugment equipped with Solarize
  110. policies = [[
  111. dict(type='Solarize', prob=1.0, level=3),
  112. ]]
  113. transform_auto = AutoAugment(policies=policies)
  114. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  115. transform_solarize = Solarize(prob=1.0, level=3)
  116. results_solarized = transform_solarize(
  117. copy.deepcopy(self.results_mask))
  118. check_result_same(results_solarized, results_auto, self.check_keys)
  119. # test AutoAugment equipped with SolarizeAdd
  120. policies = [[
  121. dict(type='SolarizeAdd', prob=1.0, level=3),
  122. ]]
  123. transform_auto = AutoAugment(policies=policies)
  124. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  125. transform_solarizeadd = SolarizeAdd(prob=1.0, level=3)
  126. results_solarizeadded = transform_solarizeadd(
  127. copy.deepcopy(self.results_mask))
  128. check_result_same(results_solarizeadded, results_auto, self.check_keys)
  129. # test AutoAugment equipped with Posterize
  130. policies = [[
  131. dict(type='Posterize', prob=1.0, level=3),
  132. ]]
  133. transform_auto = AutoAugment(policies=policies)
  134. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  135. transform_posterize = Posterize(prob=1.0, level=3)
  136. results_posterized = transform_posterize(
  137. copy.deepcopy(self.results_mask))
  138. check_result_same(results_posterized, results_auto, self.check_keys)
  139. # test AutoAugment equipped with Equalize
  140. policies = [[
  141. dict(type='Equalize', prob=1.0),
  142. ]]
  143. transform_auto = AutoAugment(policies=policies)
  144. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  145. transform_equalize = Equalize(prob=1.0)
  146. results_equalized = transform_equalize(
  147. copy.deepcopy(self.results_mask))
  148. check_result_same(results_equalized, results_auto, self.check_keys)
  149. # test AutoAugment equipped with AutoContrast
  150. policies = [[
  151. dict(type='AutoContrast', prob=1.0),
  152. ]]
  153. transform_auto = AutoAugment(policies=policies)
  154. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  155. transform_autocontrast = AutoContrast(prob=1.0)
  156. results_autocontrast = transform_autocontrast(
  157. copy.deepcopy(self.results_mask))
  158. check_result_same(results_autocontrast, results_auto, self.check_keys)
  159. # test AutoAugment equipped with Invert
  160. policies = [[
  161. dict(type='Invert', prob=1.0),
  162. ]]
  163. transform_auto = AutoAugment(policies=policies)
  164. results_auto = transform_auto(copy.deepcopy(self.results_mask))
  165. transform_invert = Invert(prob=1.0)
  166. results_inverted = transform_invert(copy.deepcopy(self.results_mask))
  167. check_result_same(results_inverted, results_auto, self.check_keys)
  168. # test AutoAugment equipped with default policies
  169. transform_auto = AutoAugment()
  170. transform_auto(copy.deepcopy(self.results_mask))
  171. def test_repr(self):
  172. policies = [[
  173. dict(type='Rotate', prob=1.0, level=10, reversal_prob=0.0),
  174. dict(type='Invert', prob=1.0),
  175. ]]
  176. transform = AutoAugment(policies=policies)
  177. self.assertEqual(
  178. repr(transform), ('AutoAugment('
  179. 'policies=[['
  180. "{'type': 'Rotate', 'prob': 1.0, "
  181. "'level': 10, 'reversal_prob': 0.0}, "
  182. "{'type': 'Invert', 'prob': 1.0}]], "
  183. 'prob=None)'))
  184. class TestRandAugment(unittest.TestCase):
  185. def setUp(self):
  186. """Setup the model and optimizer which are used in every test method.
  187. TestCase calls functions in this order: setUp() -> testMethod() ->
  188. tearDown() -> cleanUp()
  189. """
  190. self.check_keys = ('img', 'gt_bboxes', 'gt_bboxes_labels', 'gt_masks',
  191. 'gt_ignore_flags', 'gt_seg_map',
  192. 'homography_matrix')
  193. self.results_mask = construct_toy_data(poly2mask=True)
  194. self.img_fill_val = (104, 116, 124)
  195. self.seg_ignore_label = 255
  196. def test_randaugment(self):
  197. # test RandAugment equipped with Rotate
  198. aug_space = [[
  199. dict(type='Rotate', prob=1.0, level=10, reversal_prob=0.0)
  200. ]]
  201. transform_rand = RandAugment(aug_space=aug_space, aug_num=1)
  202. results_rand = transform_rand(copy.deepcopy(self.results_mask))
  203. transform_rotate = Rotate(prob=1.0, level=10, reversal_prob=0.0)
  204. results_rotated = transform_rotate(copy.deepcopy(self.results_mask))
  205. check_result_same(results_rotated, results_rand, self.check_keys)
  206. # test RandAugment equipped with default augmentation space
  207. transform_rand = RandAugment()
  208. transform_rand(copy.deepcopy(self.results_mask))
  209. def test_repr(self):
  210. aug_space = [
  211. [dict(type='Rotate')],
  212. [dict(type='Invert')],
  213. ]
  214. transform = RandAugment(aug_space=aug_space)
  215. self.assertEqual(
  216. repr(transform), ('RandAugment('
  217. 'aug_space=['
  218. "[{'type': 'Rotate'}], "
  219. "[{'type': 'Invert'}]], "
  220. 'aug_num=2, '
  221. 'prob=None)'))