# soft-teacher_faster-rcnn_r50-caffe_fpn_180k_semi-0.1-coco.py (2.5 KB)
  1. _base_ = [
  2. '../_base_/models/faster-rcnn_r50_fpn.py', '../_base_/default_runtime.py',
  3. '../_base_/datasets/semi_coco_detection.py'
  4. ]
  5. detector = _base_.model
  6. detector.data_preprocessor = dict(
  7. type='DetDataPreprocessor',
  8. mean=[103.530, 116.280, 123.675],
  9. std=[1.0, 1.0, 1.0],
  10. bgr_to_rgb=False,
  11. pad_size_divisor=32)
  12. detector.backbone = dict(
  13. type='ResNet',
  14. depth=50,
  15. num_stages=4,
  16. out_indices=(0, 1, 2, 3),
  17. frozen_stages=1,
  18. norm_cfg=dict(type='BN', requires_grad=False),
  19. norm_eval=True,
  20. style='caffe',
  21. init_cfg=dict(
  22. type='Pretrained',
  23. checkpoint='open-mmlab://detectron2/resnet50_caffe'))
  24. model = dict(
  25. _delete_=True,
  26. type='SoftTeacher',
  27. detector=detector,
  28. data_preprocessor=dict(
  29. type='MultiBranchDataPreprocessor',
  30. data_preprocessor=detector.data_preprocessor),
  31. semi_train_cfg=dict(
  32. freeze_teacher=True,
  33. sup_weight=1.0,
  34. unsup_weight=4.0,
  35. pseudo_label_initial_score_thr=0.5,
  36. rpn_pseudo_thr=0.9,
  37. cls_pseudo_thr=0.9,
  38. reg_pseudo_thr=0.02,
  39. jitter_times=10,
  40. jitter_scale=0.06,
  41. min_pseudo_bbox_wh=(1e-2, 1e-2)),
  42. semi_test_cfg=dict(predict_on='teacher'))
  43. # 10% coco train2017 is set as labeled dataset
  44. labeled_dataset = _base_.labeled_dataset
  45. unlabeled_dataset = _base_.unlabeled_dataset
  46. labeled_dataset.ann_file = 'semi_anns/instances_train2017.1@10.json'
  47. unlabeled_dataset.ann_file = 'semi_anns/' \
  48. 'instances_train2017.1@10-unlabeled.json'
  49. unlabeled_dataset.data_prefix = dict(img='train2017/')
  50. train_dataloader = dict(
  51. dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))
  52. # training schedule for 180k
  53. train_cfg = dict(
  54. type='IterBasedTrainLoop', max_iters=180000, val_interval=5000)
  55. val_cfg = dict(type='TeacherStudentValLoop')
  56. test_cfg = dict(type='TestLoop')
  57. # learning rate policy
  58. param_scheduler = [
  59. dict(
  60. type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
  61. dict(
  62. type='MultiStepLR',
  63. begin=0,
  64. end=180000,
  65. by_epoch=False,
  66. milestones=[120000, 160000],
  67. gamma=0.1)
  68. ]
  69. # optimizer
  70. optim_wrapper = dict(
  71. type='OptimWrapper',
  72. optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
  73. default_hooks = dict(
  74. checkpoint=dict(by_epoch=False, interval=10000, max_keep_ckpts=2))
  75. log_processor = dict(by_epoch=False)
  76. custom_hooks = [dict(type='MeanTeacherHook')]