test_replace_cfg_vals.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os.path as osp
  2. import tempfile
  3. from copy import deepcopy
  4. import pytest
  5. from mmengine.config import Config
  6. from mmdet.utils import replace_cfg_vals
  7. def test_replace_cfg_vals():
  8. temp_file = tempfile.NamedTemporaryFile()
  9. cfg_path = f'{temp_file.name}.py'
  10. with open(cfg_path, 'w') as f:
  11. f.write('configs')
  12. ori_cfg_dict = dict()
  13. ori_cfg_dict['cfg_name'] = osp.basename(temp_file.name)
  14. ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
  15. ori_cfg_dict['percent'] = 5
  16. ori_cfg_dict['fold'] = 1
  17. ori_cfg_dict['model_wrapper'] = dict(
  18. type='SoftTeacher', detector='${model}')
  19. ori_cfg_dict['model'] = dict(
  20. type='FasterRCNN',
  21. backbone=dict(type='ResNet'),
  22. neck=dict(type='FPN'),
  23. rpn_head=dict(type='RPNHead'),
  24. roi_head=dict(type='StandardRoIHead'),
  25. train_cfg=dict(
  26. rpn=dict(
  27. assigner=dict(type='MaxIoUAssigner'),
  28. sampler=dict(type='RandomSampler'),
  29. ),
  30. rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
  31. rcnn=dict(
  32. assigner=dict(type='MaxIoUAssigner'),
  33. sampler=dict(type='RandomSampler'),
  34. ),
  35. ),
  36. test_cfg=dict(
  37. rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
  38. rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
  39. ),
  40. )
  41. ori_cfg_dict['iou_threshold'] = dict(
  42. rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
  43. test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
  44. test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
  45. )
  46. ori_cfg_dict['str'] = 'Hello, world!'
  47. ori_cfg_dict['dict'] = {'Hello': 'world!'}
  48. ori_cfg_dict['list'] = [
  49. 'Hello, world!',
  50. ]
  51. ori_cfg_dict['tuple'] = ('Hello, world!', )
  52. ori_cfg_dict['test_str'] = 'xxx${str}xxx'
  53. ori_cfg = Config(ori_cfg_dict, filename=cfg_path)
  54. updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))
  55. assert updated_cfg.work_dir \
  56. == f'work_dirs/{osp.basename(temp_file.name)}/5/1'
  57. assert updated_cfg.model.detector == ori_cfg.model
  58. assert updated_cfg.iou_threshold.rpn_proposal_nms \
  59. == ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold
  60. assert updated_cfg.test_str == 'xxxHello, world!xxx'
  61. ori_cfg_dict['test_dict'] = 'xxx${dict}xxx'
  62. ori_cfg_dict['test_list'] = 'xxx${list}xxx'
  63. ori_cfg_dict['test_tuple'] = 'xxx${tuple}xxx'
  64. with pytest.raises(AssertionError):
  65. cfg = deepcopy(ori_cfg)
  66. cfg['test_dict'] = 'xxx${dict}xxx'
  67. updated_cfg = replace_cfg_vals(cfg)
  68. with pytest.raises(AssertionError):
  69. cfg = deepcopy(ori_cfg)
  70. cfg['test_list'] = 'xxx${list}xxx'
  71. updated_cfg = replace_cfg_vals(cfg)
  72. with pytest.raises(AssertionError):
  73. cfg = deepcopy(ori_cfg)
  74. cfg['test_tuple'] = 'xxx${tuple}xxx'
  75. updated_cfg = replace_cfg_vals(cfg)