# sparseinst_r50_iam_8xb8-ms-270k_coco.py

_base_ = [
    'mmdet::_base_/datasets/coco_instance.py',
    'mmdet::_base_/schedules/schedule_1x.py',
    'mmdet::_base_/default_runtime.py'
]
custom_imports = dict(
    imports=['projects.SparseInst.sparseinst'], allow_failed_imports=False)
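# `custom_imports` tells MMEngine to import the SparseInst project package at
# config-parse time, so the registry can resolve the string types used below
# (e.g. 'SparseInst', 'InstanceContextEncoder', 'BaseIAMDecoder').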
model = dict(
    type='SparseInst',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        # ImageNet mean/std in RGB order
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=True,
        pad_size_divisor=32),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        # output C3, C4, C5 features (512/1024/2048 channels)
        out_indices=(1, 2, 3),
        frozen_stages=0,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    encoder=dict(
        type='InstanceContextEncoder',
        in_channels=[512, 1024, 2048],
        out_channels=256),
    decoder=dict(
        type='BaseIAMDecoder',
        # 256 encoder channels plus 2 normalized coordinate channels
        in_channels=256 + 2,
        num_classes=80,
        ins_dim=256,
        ins_conv=4,
        mask_dim=256,
        mask_conv=4,
        kernel_dim=128,
        scale_factor=2.0,
        output_iam=False,
        num_masks=100),  # maximum instance predictions per image
    criterion=dict(
        type='SparseInstCriterion',
        num_classes=80,
        assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0)),
    test_cfg=dict(score_thr=0.005, mask_thr_binary=0.45))
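# How `SparseInstMatcher` uses alpha/beta above, per the SparseInst paper
# (a sketch; the actual implementation lives in projects/SparseInst/sparseinst):
# each prediction/GT pair is scored by combining classification probability
# and mask dice,
#     score = prob ** alpha * dice ** beta  (alpha=0.8, beta=0.2),
# and predictions are assigned to ground truths one-to-one by bipartite
# matching over this score.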
backend = 'pillow'
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        backend_args={{_base_.backend_args}},
        imdecode_backend=backend),
    dict(
        type='LoadAnnotations',
        with_bbox=True,
        with_mask=True,
        poly2mask=False),
    dict(
        type='RandomChoiceResize',
        scales=[(416, 853), (448, 853), (480, 853), (512, 853), (544, 853),
                (576, 853), (608, 853), (640, 853)],
        keep_ratio=True,
        backend=backend),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
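# Multi-scale training: RandomChoiceResize samples a short side from
# 416-640 in steps of 32, capping the long side at 853 (keep_ratio=True),
# followed by a 50% horizontal flip.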
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        backend_args={{_base_.backend_args}},
        imdecode_backend=backend),
    dict(type='Resize', scale=(640, 853), keep_ratio=True, backend=backend),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    sampler=dict(type='InfiniteSampler'),
    dataset=dict(pipeline=train_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
val_dataloader = test_dataloader
val_evaluator = dict(metric='segm')
test_evaluator = val_evaluator
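# SparseInst predicts instance masks directly, so only mask AP ('segm') is
# evaluated here, overriding the base config's default of bbox + segm;
# validation reuses the test dataloader and pipeline.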
# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(_delete_=True, type='AdamW', lr=0.00005, weight_decay=0.05))
train_cfg = dict(
    _delete_=True,
    type='IterBasedTrainLoop',
    max_iters=270000,
    val_interval=10000)
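# 270k iterations at the base batch size of 64 (8 GPUs x 8 images each) sees
# about 270000 * 64 / 118287 ~= 146 epochs of COCO train2017, which is why an
# InfiniteSampler with an iteration-based loop replaces epoch-based training.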
# learning rate
param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=270000,
        by_epoch=False,
        milestones=[210000, 250000],
        gamma=0.1)
]
default_hooks = dict(
    checkpoint=dict(by_epoch=False, interval=10000, max_keep_ckpts=3))
log_processor = dict(by_epoch=False)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(base_batch_size=64, enable=True)
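# A minimal usage sketch (assuming this config sits at its usual path,
# projects/SparseInst/configs/sparseinst_r50_iam_8xb8-ms-270k_coco.py, inside
# an MMDetection checkout):
#
#   # single-GPU training
#   python tools/train.py \
#       projects/SparseInst/configs/sparseinst_r50_iam_8xb8-ms-270k_coco.py
#
#   # 8-GPU training, matching the 8xb8 schedule in the filename
#   bash tools/dist_train.sh \
#       projects/SparseInst/configs/sparseinst_r50_iam_8xb8-ms-270k_coco.py 8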