# File: ld_r18-gflv1-r101_fpn_1x_coco.py (listed size: 2.3 KB)

  1. _base_ = [
  2. '../_base_/datasets/coco_detection.py',
  3. '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
  4. ]
  5. teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa
  6. model = dict(
  7. type='KnowledgeDistillationSingleStageDetector',
  8. data_preprocessor=dict(
  9. type='DetDataPreprocessor',
  10. mean=[123.675, 116.28, 103.53],
  11. std=[58.395, 57.12, 57.375],
  12. bgr_to_rgb=True,
  13. pad_size_divisor=32),
  14. teacher_config='configs/gfl/gfl_r101_fpn_ms-2x_coco.py',
  15. teacher_ckpt=teacher_ckpt,
  16. backbone=dict(
  17. type='ResNet',
  18. depth=18,
  19. num_stages=4,
  20. out_indices=(0, 1, 2, 3),
  21. frozen_stages=1,
  22. norm_cfg=dict(type='BN', requires_grad=True),
  23. norm_eval=True,
  24. style='pytorch',
  25. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
  26. neck=dict(
  27. type='FPN',
  28. in_channels=[64, 128, 256, 512],
  29. out_channels=256,
  30. start_level=1,
  31. add_extra_convs='on_output',
  32. num_outs=5),
  33. bbox_head=dict(
  34. type='LDHead',
  35. num_classes=80,
  36. in_channels=256,
  37. stacked_convs=4,
  38. feat_channels=256,
  39. anchor_generator=dict(
  40. type='AnchorGenerator',
  41. ratios=[1.0],
  42. octave_base_scale=8,
  43. scales_per_octave=1,
  44. strides=[8, 16, 32, 64, 128]),
  45. loss_cls=dict(
  46. type='QualityFocalLoss',
  47. use_sigmoid=True,
  48. beta=2.0,
  49. loss_weight=1.0),
  50. loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
  51. loss_ld=dict(
  52. type='KnowledgeDistillationKLDivLoss', loss_weight=0.25, T=10),
  53. reg_max=16,
  54. loss_bbox=dict(type='GIoULoss', loss_weight=2.0)),
  55. # training and testing settings
  56. train_cfg=dict(
  57. assigner=dict(type='ATSSAssigner', topk=9),
  58. allowed_border=-1,
  59. pos_weight=-1,
  60. debug=False),
  61. test_cfg=dict(
  62. nms_pre=1000,
  63. min_bbox_size=0,
  64. score_thr=0.05,
  65. nms=dict(type='nms', iou_threshold=0.6),
  66. max_per_img=100))
  67. optim_wrapper = dict(
  68. type='OptimWrapper',
  69. optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))