# Source file: test_semi_base.py (1.1 KB)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmengine.registry import MODELS
from parameterized import parameterized

from mmdet.testing import get_detector_cfg
from mmdet.utils import register_all_modules

# Executed once at module import time, before any test runs. Presumably this
# registers mmdet's modules with the mmengine registries so that
# MODELS.build() can resolve the detector types named in the test config —
# confirm against mmdet.utils.register_all_modules.
register_all_modules()
  8. class TestSemiBase(TestCase):
  9. @parameterized.expand([
  10. 'soft_teacher/'
  11. 'soft-teacher_faster-rcnn_r50-caffe_fpn_180k_semi-0.1-coco.py',
  12. ])
  13. def test_init(self, cfg_file):
  14. model = get_detector_cfg(cfg_file)
  15. # backbone convert to ResNet18
  16. model.detector.backbone.depth = 18
  17. model.detector.neck.in_channels = [64, 128, 256, 512]
  18. model.detector.backbone.init_cfg = None
  19. model = MODELS.build(model)
  20. self.assertTrue(model.teacher.backbone)
  21. self.assertTrue(model.teacher.neck)
  22. self.assertTrue(model.teacher.rpn_head)
  23. self.assertTrue(model.teacher.roi_head)
  24. self.assertTrue(model.student.backbone)
  25. self.assertTrue(model.student.neck)
  26. self.assertTrue(model.student.rpn_head)
  27. self.assertTrue(model.student.roi_head)