1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import os.path as osp
- import tempfile
- from copy import deepcopy
- import pytest
- from mmengine.config import Config
- from mmdet.utils import replace_cfg_vals
def test_replace_cfg_vals():
    """Check that ``replace_cfg_vals`` resolves ``${key}`` references.

    Covered behaviour (as asserted below):

    * ``${a}/${b}`` fragments inside a longer string are substituted with the
      string form of the referenced scalar values;
    * a string that is exactly ``${key}`` is replaced by the referenced value
      itself (a whole sub-config, or a nested scalar via dotted keys);
    * ``model_wrapper`` ends up as ``model`` in the returned config, with its
      ``detector`` set to the original model;
    * embedding a dict/list/tuple inside a longer string raises
      ``AssertionError``.
    """
    # Use a temporary *directory* so that the placeholder ``.py`` file is
    # removed when the test ends (the old NamedTemporaryFile + sibling
    # ``.py`` approach left the ``.py`` file behind on every run).
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_name = 'test_replace_cfg_vals'
        cfg_path = osp.join(tmp_dir, f'{cfg_name}.py')
        # Write a placeholder so ``filename`` points at an existing file.
        with open(cfg_path, 'w', encoding='utf-8') as f:
            f.write('configs')

        ori_cfg_dict = dict()
        ori_cfg_dict['cfg_name'] = cfg_name
        ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
        ori_cfg_dict['percent'] = 5
        ori_cfg_dict['fold'] = 1
        ori_cfg_dict['model_wrapper'] = dict(
            type='SoftTeacher', detector='${model}')
        ori_cfg_dict['model'] = dict(
            type='FasterRCNN',
            backbone=dict(type='ResNet'),
            neck=dict(type='FPN'),
            rpn_head=dict(type='RPNHead'),
            roi_head=dict(type='StandardRoIHead'),
            train_cfg=dict(
                rpn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
                rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
            ),
            test_cfg=dict(
                rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
            ),
        )
        # Dotted references into nested sub-configs.
        ori_cfg_dict['iou_threshold'] = dict(
            rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
            test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
            test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
        )
        ori_cfg_dict['str'] = 'Hello, world!'
        ori_cfg_dict['dict'] = {'Hello': 'world!'}
        ori_cfg_dict['list'] = [
            'Hello, world!',
        ]
        ori_cfg_dict['tuple'] = ('Hello, world!', )
        ori_cfg_dict['test_str'] = 'xxx${str}xxx'
        ori_cfg = Config(ori_cfg_dict, filename=cfg_path)

        # deepcopy so ``ori_cfg`` stays untouched for the comparisons below.
        updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))

        assert updated_cfg.work_dir == f'work_dirs/{cfg_name}/5/1'
        # ``model_wrapper`` is expected to take the place of ``model``, with
        # its ``${model}`` reference resolved to the original model config.
        assert updated_cfg.model.detector == ori_cfg.model
        assert updated_cfg.iou_threshold.rpn_proposal_nms \
            == ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold
        assert updated_cfg.iou_threshold.test_rpn_nms \
            == ori_cfg.model.test_cfg.rpn.nms.iou_threshold
        assert updated_cfg.iou_threshold.test_rcnn_nms \
            == ori_cfg.model.test_cfg.rcnn.nms.iou_threshold
        assert updated_cfg.test_str == 'xxxHello, world!xxx'

        # A container value cannot be spliced into a longer string.
        invalid_refs = {
            'test_dict': 'xxx${dict}xxx',
            'test_list': 'xxx${list}xxx',
            'test_tuple': 'xxx${tuple}xxx',
        }
        for cfg_key, cfg_value in invalid_refs.items():
            # Build the config outside ``pytest.raises`` so an unrelated
            # AssertionError during setup cannot make this check pass.
            cfg = deepcopy(ori_cfg)
            cfg[cfg_key] = cfg_value
            with pytest.raises(AssertionError):
                replace_cfg_vals(cfg)
|