test_tta.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from unittest import TestCase
  4. import mmcv
  5. import pytest
  6. from mmdet.datasets.transforms import * # noqa
  7. from mmdet.registry import TRANSFORMS
  8. class TestMuitiScaleFlipAug(TestCase):
  9. def test_exception(self):
  10. with pytest.raises(TypeError):
  11. tta_transform = dict(
  12. type='TestTimeAug',
  13. transforms=[dict(type='Resize', keep_ratio=False)],
  14. )
  15. TRANSFORMS.build(tta_transform)
  16. def test_multi_scale_flip_aug(self):
  17. tta_transform = dict(
  18. type='TestTimeAug',
  19. transforms=[[
  20. dict(type='Resize', scale=scale, keep_ratio=False)
  21. for scale in [(256, 256), (512, 512), (1024, 1024)]
  22. ],
  23. [
  24. dict(
  25. type='mmdet.PackDetInputs',
  26. meta_keys=('img_id', 'img_path', 'ori_shape',
  27. 'img_shape', 'scale_factor'))
  28. ]])
  29. tta_module = TRANSFORMS.build(tta_transform)
  30. results = dict()
  31. img = mmcv.imread(
  32. osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
  33. results['img_id'] = '1'
  34. results['img_path'] = 'data/color.jpg'
  35. results['img'] = img
  36. results['ori_shape'] = img.shape
  37. results['ori_height'] = img.shape[0]
  38. results['ori_width'] = img.shape[1]
  39. # Set initial values for default meta_keys
  40. results['pad_shape'] = img.shape
  41. results['scale_factor'] = 1.0
  42. tta_results = tta_module(results.copy())
  43. assert [img.shape
  44. for img in tta_results['inputs']] == [(3, 256, 256),
  45. (3, 512, 512),
  46. (3, 1024, 1024)]
  47. tta_transform = dict(
  48. type='TestTimeAug',
  49. transforms=[
  50. [
  51. dict(type='Resize', scale=scale, keep_ratio=False)
  52. for scale in [(256, 256), (512, 512), (1024, 1024)]
  53. ],
  54. [
  55. dict(type='RandomFlip', prob=0., direction='horizontal'),
  56. dict(type='RandomFlip', prob=1., direction='horizontal')
  57. ],
  58. [
  59. dict(
  60. type='mmdet.PackDetInputs',
  61. meta_keys=('img_id', 'img_path', 'ori_shape',
  62. 'img_shape', 'scale_factor', 'flip',
  63. 'flip_direction'))
  64. ]
  65. ])
  66. tta_module = TRANSFORMS.build(tta_transform)
  67. tta_results: dict = tta_module(results.copy())
  68. assert [img.shape
  69. for img in tta_results['inputs']] == [(3, 256, 256),
  70. (3, 256, 256),
  71. (3, 512, 512),
  72. (3, 512, 512),
  73. (3, 1024, 1024),
  74. (3, 1024, 1024)]
  75. assert [
  76. data_sample.metainfo['flip']
  77. for data_sample in tta_results['data_samples']
  78. ] == [False, True, False, True, False, True]
  79. tta_transform = dict(
  80. type='TestTimeAug',
  81. transforms=[[
  82. dict(type='Resize', scale=(512, 512), keep_ratio=False)
  83. ],
  84. [
  85. dict(
  86. type='mmdet.PackDetInputs',
  87. meta_keys=('img_id', 'img_path', 'ori_shape',
  88. 'img_shape', 'scale_factor'))
  89. ]])
  90. tta_module = TRANSFORMS.build(tta_transform)
  91. tta_results = tta_module(results.copy())
  92. assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]
  93. tta_transform = dict(
  94. type='TestTimeAug',
  95. transforms=[
  96. [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
  97. [
  98. dict(type='RandomFlip', prob=0., direction='horizontal'),
  99. dict(type='RandomFlip', prob=1., direction='horizontal')
  100. ],
  101. [
  102. dict(
  103. type='mmdet.PackDetInputs',
  104. meta_keys=('img_id', 'img_path', 'ori_shape',
  105. 'img_shape', 'scale_factor', 'flip',
  106. 'flip_direction'))
  107. ]
  108. ])
  109. tta_module = TRANSFORMS.build(tta_transform)
  110. tta_results = tta_module(results.copy())
  111. assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
  112. (3, 512, 512)]
  113. assert [
  114. data_sample.metainfo['flip']
  115. for data_sample in tta_results['data_samples']
  116. ] == [False, True]
  117. tta_transform = dict(
  118. type='TestTimeAug',
  119. transforms=[[
  120. dict(type='Resize', scale_factor=r, keep_ratio=False)
  121. for r in [0.5, 1.0, 2.0]
  122. ],
  123. [
  124. dict(
  125. type='mmdet.PackDetInputs',
  126. meta_keys=('img_id', 'img_path', 'ori_shape',
  127. 'img_shape', 'scale_factor'))
  128. ]])
  129. tta_module = TRANSFORMS.build(tta_transform)
  130. tta_results = tta_module(results.copy())
  131. assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
  132. (3, 288, 512),
  133. (3, 576, 1024)]
  134. tta_transform = dict(
  135. type='TestTimeAug',
  136. transforms=[
  137. [
  138. dict(type='Resize', scale_factor=r, keep_ratio=True)
  139. for r in [0.5, 1.0, 2.0]
  140. ],
  141. [
  142. dict(type='RandomFlip', prob=0., direction='horizontal'),
  143. dict(type='RandomFlip', prob=1., direction='horizontal')
  144. ],
  145. [
  146. dict(
  147. type='mmdet.PackDetInputs',
  148. meta_keys=('img_id', 'img_path', 'ori_shape',
  149. 'img_shape', 'scale_factor', 'flip',
  150. 'flip_direction'))
  151. ]
  152. ])
  153. tta_module = TRANSFORMS.build(tta_transform)
  154. tta_results = tta_module(results.copy())
  155. assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
  156. (3, 144, 256),
  157. (3, 288, 512),
  158. (3, 288, 512),
  159. (3, 576, 1024),
  160. (3, 576, 1024)]
  161. assert [
  162. data_sample.metainfo['flip']
  163. for data_sample in tta_results['data_samples']
  164. ] == [False, True, False, True, False, True]