detr_r50_8xb2-150e_coco.py

_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
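# The base files provide the COCO dataloaders/evaluators and the default
# runtime settings (hooks, logging, checkpointing); only the DETR-specific
# pieces are defined below.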
model = dict(
    type='DETR',
    num_queries=100,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=None,
        num_outs=1),
    encoder=dict(  # DetrTransformerEncoder
        num_layers=6,
        layer_cfg=dict(  # DetrTransformerEncoderLayer
            self_attn_cfg=dict(  # MultiheadAttention
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True)))),
    decoder=dict(  # DetrTransformerDecoder
        num_layers=6,
        layer_cfg=dict(  # DetrTransformerDecoderLayer
            self_attn_cfg=dict(  # MultiheadAttention
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                batch_first=True),
            cross_attn_cfg=dict(  # MultiheadAttention
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.1,
                act_cfg=dict(type='ReLU', inplace=True))),
        return_intermediate=True),
    positional_encoding=dict(num_feats=128, normalize=True),
    bbox_head=dict(
        type='DETRHead',
        num_classes=80,
        embed_dims=256,
        loss_cls=dict(
            type='CrossEntropyLoss',
            bg_cls_weight=0.1,
            use_sigmoid=False,
            loss_weight=1.0,
            class_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='ClassificationCost', weight=1.),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=100))
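# The sine positional encoding uses num_feats=128 per axis (x and y), so the
# concatenated embedding matches the transformer embed_dims of 256. The
# Hungarian matching cost weights (1.0 / 5.0 / 2.0) mirror the classification,
# L1 box and GIoU loss weights used by the head.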
# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default settings in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
        transforms=[[
            dict(
                type='RandomChoiceResize',
                scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                        (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                        (736, 1333), (768, 1333), (800, 1333)],
                keep_ratio=True)
        ],
                    [
                        dict(
                            type='RandomChoiceResize',
                            scales=[(400, 1333), (500, 1333), (600, 1333)],
                            keep_ratio=True),
                        dict(
                            type='RandomCrop',
                            crop_type='absolute_range',
                            crop_size=(384, 600),
                            allow_negative_crop=True),
                        dict(
                            type='RandomChoiceResize',
                            scales=[(480, 1333), (512, 1333), (544, 1333),
                                    (576, 1333), (608, 1333), (640, 1333),
                                    (672, 1333), (704, 1333), (736, 1333),
                                    (768, 1333), (800, 1333)],
                            keep_ratio=True)
                    ]]),
    dict(type='PackDetInputs')
]
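# RandomChoice picks one branch per image: either a single multi-scale resize,
# or a smaller resize followed by a random crop (height/width sampled from
# 384-600 px) and a second multi-scale resize.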
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
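# With lr_mult=0.1, backbone parameters are updated at one tenth of the base
# learning rate (1e-5 instead of 1e-4), and gradients are clipped to an L2
# norm of 0.1.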
# learning policy
max_epochs = 150
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[100],
        gamma=0.1)
]
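# The only scheduler defined here is MultiStepLR: the learning rate is
# multiplied by gamma=0.1 once, at epoch 100 of the 150-epoch schedule.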
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
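
# A minimal usage sketch (kept as comments so the config stays inert; it
# assumes an MMDetection 3.x checkout with the standard tools/ scripts and
# this file located at configs/detr/detr_r50_8xb2-150e_coco.py -- paths and
# GPU count are illustrative):
#
#   # single GPU
#   python tools/train.py configs/detr/detr_r50_8xb2-150e_coco.py
#
#   # 8 GPUs, matching the 8xb2 schedule this config is named after
#   bash tools/dist_train.sh configs/detr/detr_r50_8xb2-150e_coco.py 8
#
# Or programmatically via MMEngine (work_dir is a required, user-chosen path):
#
#   from mmengine.config import Config
#   from mmengine.runner import Runner
#   cfg = Config.fromfile('configs/detr/detr_r50_8xb2-150e_coco.py')
#   cfg.work_dir = './work_dirs/detr_r50_8xb2-150e_coco'
#   Runner.from_cfg(cfg).train()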