# dino-4scale_r50_8xb2-12e_coco.py (5.6 KB)

  1. _base_ = [
  2. '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
  3. ]
  4. model = dict(
  5. type='DINO',
  6. num_queries=900, # num_matching_queries
  7. with_box_refine=True,
  8. as_two_stage=True,
  9. data_preprocessor=dict(
  10. type='DetDataPreprocessor',
  11. mean=[123.675, 116.28, 103.53],
  12. std=[58.395, 57.12, 57.375],
  13. bgr_to_rgb=True,
  14. pad_size_divisor=1),
  15. backbone=dict(
  16. type='ResNet',
  17. depth=50,
  18. num_stages=4,
  19. out_indices=(1, 2, 3),
  20. frozen_stages=1,
  21. norm_cfg=dict(type='BN', requires_grad=False),
  22. norm_eval=True,
  23. style='pytorch',
  24. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
  25. neck=dict(
  26. type='ChannelMapper',
  27. in_channels=[512, 1024, 2048],
  28. kernel_size=1,
  29. out_channels=256,
  30. act_cfg=None,
  31. norm_cfg=dict(type='GN', num_groups=32),
  32. num_outs=4),
  33. encoder=dict(
  34. num_layers=6,
  35. layer_cfg=dict(
  36. self_attn_cfg=dict(embed_dims=256, num_levels=4,
  37. dropout=0.0), # 0.1 for DeformDETR
  38. ffn_cfg=dict(
  39. embed_dims=256,
  40. feedforward_channels=2048, # 1024 for DeformDETR
  41. ffn_drop=0.0))), # 0.1 for DeformDETR
  42. decoder=dict(
  43. num_layers=6,
  44. return_intermediate=True,
  45. layer_cfg=dict(
  46. self_attn_cfg=dict(embed_dims=256, num_heads=8,
  47. dropout=0.0), # 0.1 for DeformDETR
  48. cross_attn_cfg=dict(embed_dims=256, num_levels=4,
  49. dropout=0.0), # 0.1 for DeformDETR
  50. ffn_cfg=dict(
  51. embed_dims=256,
  52. feedforward_channels=2048, # 1024 for DeformDETR
  53. ffn_drop=0.0)), # 0.1 for DeformDETR
  54. post_norm_cfg=None),
  55. positional_encoding=dict(
  56. num_feats=128,
  57. normalize=True,
  58. offset=0.0, # -0.5 for DeformDETR
  59. temperature=20), # 10000 for DeformDETR
  60. bbox_head=dict(
  61. type='DINOHead',
  62. num_classes=80,
  63. sync_cls_avg_factor=True,
  64. loss_cls=dict(
  65. type='FocalLoss',
  66. use_sigmoid=True,
  67. gamma=2.0,
  68. alpha=0.25,
  69. loss_weight=1.0), # 2.0 in DeformDETR
  70. loss_bbox=dict(type='L1Loss', loss_weight=5.0),
  71. loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
  72. dn_cfg=dict( # TODO: Move to model.train_cfg ?
  73. label_noise_scale=0.5,
  74. box_noise_scale=1.0, # 0.4 for DN-DETR
  75. group_cfg=dict(dynamic=True, num_groups=None,
  76. num_dn_queries=100)), # TODO: half num_dn_queries
  77. # training and testing settings
  78. train_cfg=dict(
  79. assigner=dict(
  80. type='HungarianAssigner',
  81. match_costs=[
  82. dict(type='FocalLossCost', weight=2.0),
  83. dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
  84. dict(type='IoUCost', iou_mode='giou', weight=2.0)
  85. ])),
  86. test_cfg=dict(max_per_img=300)) # 100 for DeformDETR
  87. # train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
  88. # from the default setting in mmdet.
  89. train_pipeline = [
  90. dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
  91. dict(type='LoadAnnotations', with_bbox=True),
  92. dict(type='RandomFlip', prob=0.5),
  93. dict(
  94. type='RandomChoice',
  95. transforms=[
  96. [
  97. dict(
  98. type='RandomChoiceResize',
  99. scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  100. (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  101. (736, 1333), (768, 1333), (800, 1333)],
  102. keep_ratio=True)
  103. ],
  104. [
  105. dict(
  106. type='RandomChoiceResize',
  107. # The radio of all image in train dataset < 7
  108. # follow the original implement
  109. scales=[(400, 4200), (500, 4200), (600, 4200)],
  110. keep_ratio=True),
  111. dict(
  112. type='RandomCrop',
  113. crop_type='absolute_range',
  114. crop_size=(384, 600),
  115. allow_negative_crop=True),
  116. dict(
  117. type='RandomChoiceResize',
  118. scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
  119. (608, 1333), (640, 1333), (672, 1333), (704, 1333),
  120. (736, 1333), (768, 1333), (800, 1333)],
  121. keep_ratio=True)
  122. ]
  123. ]),
  124. dict(type='PackDetInputs')
  125. ]
  126. train_dataloader = dict(
  127. dataset=dict(
  128. filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))
  129. # optimizer
  130. optim_wrapper = dict(
  131. type='OptimWrapper',
  132. optimizer=dict(
  133. type='AdamW',
  134. lr=0.0001, # 0.0002 for DeformDETR
  135. weight_decay=0.0001),
  136. clip_grad=dict(max_norm=0.1, norm_type=2),
  137. paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
  138. ) # custom_keys contains sampling_offsets and reference_points in DeformDETR # noqa
  139. # learning policy
  140. max_epochs = 12
  141. train_cfg = dict(
  142. type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
  143. val_cfg = dict(type='ValLoop')
  144. test_cfg = dict(type='TestLoop')
  145. param_scheduler = [
  146. dict(
  147. type='MultiStepLR',
  148. begin=0,
  149. end=max_epochs,
  150. by_epoch=True,
  151. milestones=[11],
  152. gamma=0.1)
  153. ]
  154. # NOTE: `auto_scale_lr` is for automatically scaling LR,
  155. # USER SHOULD NOT CHANGE ITS VALUES.
  156. # base_batch_size = (8 GPUs) x (2 samples per GPU)
  157. auto_scale_lr = dict(base_batch_size=16)