123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- from unittest import TestCase
- import mmcv
- import pytest
- from mmdet.datasets.transforms import * # noqa
- from mmdet.registry import TRANSFORMS
class TestMuitiScaleFlipAug(TestCase):
    """Checks how ``TestTimeAug`` expands nested transform lists.

    Each pipeline is built through ``TRANSFORMS`` and applied to a copy of
    the same sample dict; the assertions pin the number, order and shape of
    the packed outputs.
    """

    # Meta keys requested from ``mmdet.PackDetInputs`` without a flip step.
    _BASE_META_KEYS = ('img_id', 'img_path', 'ori_shape', 'img_shape',
                       'scale_factor')
    # Same keys plus the two flip-related entries.
    _FLIP_META_KEYS = _BASE_META_KEYS + ('flip', 'flip_direction')

    @staticmethod
    def _resize_branch(key, values, keep_ratio=False):
        """Return one fresh ``Resize`` config per entry of ``values``.

        ``key`` is the ``Resize`` argument being varied (``'scale'`` or
        ``'scale_factor'``).
        """
        return [{
            'type': 'Resize',
            key: value,
            'keep_ratio': keep_ratio
        } for value in values]

    @staticmethod
    def _flip_branch():
        """Return a no-flip / always-flip pair of ``RandomFlip`` configs."""
        return [{
            'type': 'RandomFlip',
            'prob': prob,
            'direction': 'horizontal'
        } for prob in (0., 1.)]

    @staticmethod
    def _pack_branch(meta_keys):
        """Return the single-element ``mmdet.PackDetInputs`` branch."""
        return [{'type': 'mmdet.PackDetInputs', 'meta_keys': meta_keys}]

    @staticmethod
    def _run_tta(branches, sample):
        """Build ``TestTimeAug`` from ``branches``; apply to a sample copy."""
        tta_module = TRANSFORMS.build({
            'type': 'TestTimeAug',
            'transforms': branches
        })
        return tta_module(sample.copy())

    @staticmethod
    def _input_shapes(tta_results):
        """Collect the shape of every entry in ``tta_results['inputs']``."""
        return [tensor.shape for tensor in tta_results['inputs']]

    @staticmethod
    def _flip_flags(tta_results):
        """Collect the ``flip`` meta value of every packed data sample."""
        return [
            sample.metainfo['flip'] for sample in tta_results['data_samples']
        ]

    def test_exception(self):
        """``transforms`` given as a flat list of dicts raises TypeError."""
        tta_transform = {
            'type': 'TestTimeAug',
            'transforms': [{
                'type': 'Resize',
                'keep_ratio': False
            }],
        }
        with pytest.raises(TypeError):
            TRANSFORMS.build(tta_transform)

    def test_multi_scale_flip_aug(self):
        """Run six resize / flip combinations and check the packed outputs."""
        image_file = osp.join(osp.dirname(__file__), '../data/color.jpg')
        img = mmcv.imread(image_file, 'color')
        results = {
            'img_id': '1',
            'img_path': 'data/color.jpg',
            'img': img,
            'ori_shape': img.shape,
            'ori_height': img.shape[0],
            'ori_width': img.shape[1],
            # Set initial values for default meta_keys
            'pad_shape': img.shape,
            'scale_factor': 1.0,
        }
        square_scales = [(256, 256), (512, 512), (1024, 1024)]
        ratios = [0.5, 1.0, 2.0]

        # 1) Three fixed scales, no flip: one output per scale.
        tta_results = self._run_tta([
            self._resize_branch('scale', square_scales),
            self._pack_branch(self._BASE_META_KEYS),
        ], results)
        assert self._input_shapes(tta_results) == [
            (3, 256, 256),
            (3, 512, 512),
            (3, 1024, 1024),
        ]

        # 2) Three fixed scales x {no flip, flip}: the flip variants are
        #    asserted to alternate within each scale.
        tta_results = self._run_tta([
            self._resize_branch('scale', square_scales),
            self._flip_branch(),
            self._pack_branch(self._FLIP_META_KEYS),
        ], results)
        assert self._input_shapes(tta_results) == [
            (3, 256, 256),
            (3, 256, 256),
            (3, 512, 512),
            (3, 512, 512),
            (3, 1024, 1024),
            (3, 1024, 1024),
        ]
        assert self._flip_flags(tta_results) == [
            False, True, False, True, False, True
        ]

        # 3) A single fixed scale, no flip: only the first input is checked.
        tta_results = self._run_tta([
            self._resize_branch('scale', [(512, 512)]),
            self._pack_branch(self._BASE_META_KEYS),
        ], results)
        assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]

        # 4) A single fixed scale x {no flip, flip}.
        tta_results = self._run_tta([
            self._resize_branch('scale', [(512, 512)]),
            self._flip_branch(),
            self._pack_branch(self._FLIP_META_KEYS),
        ], results)
        assert self._input_shapes(tta_results) == [
            (3, 512, 512),
            (3, 512, 512),
        ]
        assert self._flip_flags(tta_results) == [False, True]

        # 5) Three scale factors, no flip. The expected shape at factor 1.0
        #    implies the fixture image is 288x512.
        tta_results = self._run_tta([
            self._resize_branch('scale_factor', ratios),
            self._pack_branch(self._BASE_META_KEYS),
        ], results)
        assert self._input_shapes(tta_results) == [
            (3, 144, 256),
            (3, 288, 512),
            (3, 576, 1024),
        ]

        # 6) Three scale factors (this time with keep_ratio=True) x
        #    {no flip, flip}.
        tta_results = self._run_tta([
            self._resize_branch('scale_factor', ratios, keep_ratio=True),
            self._flip_branch(),
            self._pack_branch(self._FLIP_META_KEYS),
        ], results)
        assert self._input_shapes(tta_results) == [
            (3, 144, 256),
            (3, 144, 256),
            (3, 288, 512),
            (3, 288, 512),
            (3, 576, 1024),
            (3, 576, 1024),
        ]
        assert self._flip_flags(tta_results) == [
            False, True, False, True, False, True
        ]
|