# mask-rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py
  1. _base_ = [
  2. '../_base_/models/mask-rcnn_r50_fpn.py',
  3. '../_base_/datasets/coco_instance.py',
  4. '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
  5. ]
  6. pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa
  7. model = dict(
  8. type='MaskRCNN',
  9. backbone=dict(
  10. _delete_=True,
  11. type='SwinTransformer',
  12. embed_dims=96,
  13. depths=[2, 2, 6, 2],
  14. num_heads=[3, 6, 12, 24],
  15. window_size=7,
  16. mlp_ratio=4,
  17. qkv_bias=True,
  18. qk_scale=None,
  19. drop_rate=0.,
  20. attn_drop_rate=0.,
  21. drop_path_rate=0.2,
  22. patch_norm=True,
  23. out_indices=(0, 1, 2, 3),
  24. with_cp=False,
  25. convert_weights=True,
  26. init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
  27. neck=dict(in_channels=[96, 192, 384, 768]))
  28. # augmentation strategy originates from DETR / Sparse RCNN
  29. train_pipeline = [
  30. dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
  31. dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
  32. dict(type='RandomFlip', prob=0.5),
  33. dict(
  34. type='RandomChoice',
  35. transforms=[[
  36. dict(
  37. type='RandomChoiceResize',
  38. scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  39. (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  40. (736, 1333), (768, 1333), (800, 1333)],
  41. keep_ratio=True)
  42. ],
  43. [
  44. dict(
  45. type='RandomChoiceResize',
  46. scales=[(400, 1333), (500, 1333), (600, 1333)],
  47. keep_ratio=True),
  48. dict(
  49. type='RandomCrop',
  50. crop_type='absolute_range',
  51. crop_size=(384, 600),
  52. allow_negative_crop=True),
  53. dict(
  54. type='RandomChoiceResize',
  55. scales=[(480, 1333), (512, 1333), (544, 1333),
  56. (576, 1333), (608, 1333), (640, 1333),
  57. (672, 1333), (704, 1333), (736, 1333),
  58. (768, 1333), (800, 1333)],
  59. keep_ratio=True)
  60. ]]),
  61. dict(type='PackDetInputs')
  62. ]
  63. train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
  64. max_epochs = 36
  65. train_cfg = dict(max_epochs=max_epochs)
  66. # learning rate
  67. param_scheduler = [
  68. dict(
  69. type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
  70. end=1000),
  71. dict(
  72. type='MultiStepLR',
  73. begin=0,
  74. end=max_epochs,
  75. by_epoch=True,
  76. milestones=[27, 33],
  77. gamma=0.1)
  78. ]
  79. # optimizer
  80. optim_wrapper = dict(
  81. type='OptimWrapper',
  82. paramwise_cfg=dict(
  83. custom_keys={
  84. 'absolute_pos_embed': dict(decay_mult=0.),
  85. 'relative_position_bias_table': dict(decay_mult=0.),
  86. 'norm': dict(decay_mult=0.)
  87. }),
  88. optimizer=dict(
  89. _delete_=True,
  90. type='AdamW',
  91. lr=0.0001,
  92. betas=(0.9, 0.999),
  93. weight_decay=0.05))