# yolact_r50_1xb8-55e_coco.py

  1. _base_ = [
  2. '../_base_/datasets/coco_instance.py', '../_base_/default_runtime.py'
  3. ]
  4. img_norm_cfg = dict(
  5. mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True)
  6. # model settings
  7. input_size = 550
  8. model = dict(
  9. type='YOLACT',
  10. data_preprocessor=dict(
  11. type='DetDataPreprocessor',
  12. mean=img_norm_cfg['mean'],
  13. std=img_norm_cfg['std'],
  14. bgr_to_rgb=img_norm_cfg['to_rgb'],
  15. pad_mask=True),
  16. backbone=dict(
  17. type='ResNet',
  18. depth=50,
  19. num_stages=4,
  20. out_indices=(0, 1, 2, 3),
  21. frozen_stages=-1, # do not freeze stem
  22. norm_cfg=dict(type='BN', requires_grad=True),
  23. norm_eval=False, # update the statistics of bn
  24. zero_init_residual=False,
  25. style='pytorch',
  26. init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
  27. neck=dict(
  28. type='FPN',
  29. in_channels=[256, 512, 1024, 2048],
  30. out_channels=256,
  31. start_level=1,
  32. add_extra_convs='on_input',
  33. num_outs=5,
  34. upsample_cfg=dict(mode='bilinear')),
  35. bbox_head=dict(
  36. type='YOLACTHead',
  37. num_classes=80,
  38. in_channels=256,
  39. feat_channels=256,
  40. anchor_generator=dict(
  41. type='AnchorGenerator',
  42. octave_base_scale=3,
  43. scales_per_octave=1,
  44. base_sizes=[8, 16, 32, 64, 128],
  45. ratios=[0.5, 1.0, 2.0],
  46. strides=[550.0 / x for x in [69, 35, 18, 9, 5]],
  47. centers=[(550 * 0.5 / x, 550 * 0.5 / x)
  48. for x in [69, 35, 18, 9, 5]]),
  49. bbox_coder=dict(
  50. type='DeltaXYWHBBoxCoder',
  51. target_means=[.0, .0, .0, .0],
  52. target_stds=[0.1, 0.1, 0.2, 0.2]),
  53. loss_cls=dict(
  54. type='CrossEntropyLoss',
  55. use_sigmoid=False,
  56. reduction='none',
  57. loss_weight=1.0),
  58. loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
  59. num_head_convs=1,
  60. num_protos=32,
  61. use_ohem=True),
  62. mask_head=dict(
  63. type='YOLACTProtonet',
  64. in_channels=256,
  65. num_protos=32,
  66. num_classes=80,
  67. max_masks_to_train=100,
  68. loss_mask_weight=6.125,
  69. with_seg_branch=True,
  70. loss_segm=dict(
  71. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
  72. # training and testing settings
  73. train_cfg=dict(
  74. assigner=dict(
  75. type='MaxIoUAssigner',
  76. pos_iou_thr=0.5,
  77. neg_iou_thr=0.4,
  78. min_pos_iou=0.,
  79. ignore_iof_thr=-1,
  80. gt_max_assign_all=False),
  81. sampler=dict(type='PseudoSampler'), # YOLACT should use PseudoSampler
  82. # smoothl1_beta=1.,
  83. allowed_border=-1,
  84. pos_weight=-1,
  85. neg_pos_ratio=3,
  86. debug=False),
  87. test_cfg=dict(
  88. nms_pre=1000,
  89. min_bbox_size=0,
  90. score_thr=0.05,
  91. mask_thr=0.5,
  92. iou_thr=0.5,
  93. top_k=200,
  94. max_per_img=100,
  95. mask_thr_binary=0.5))
  96. # dataset settings
  97. train_pipeline = [
  98. dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
  99. dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
  100. dict(type='FilterAnnotations', min_gt_bbox_wh=(4.0, 4.0)),
  101. dict(
  102. type='Expand',
  103. mean=img_norm_cfg['mean'],
  104. to_rgb=img_norm_cfg['to_rgb'],
  105. ratio_range=(1, 4)),
  106. dict(
  107. type='MinIoURandomCrop',
  108. min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
  109. min_crop_size=0.3),
  110. dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
  111. dict(type='RandomFlip', prob=0.5),
  112. dict(
  113. type='PhotoMetricDistortion',
  114. brightness_delta=32,
  115. contrast_range=(0.5, 1.5),
  116. saturation_range=(0.5, 1.5),
  117. hue_delta=18),
  118. dict(type='PackDetInputs')
  119. ]
  120. test_pipeline = [
  121. dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
  122. dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
  123. dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
  124. dict(
  125. type='PackDetInputs',
  126. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  127. 'scale_factor'))
  128. ]
  129. train_dataloader = dict(
  130. batch_size=8,
  131. num_workers=4,
  132. batch_sampler=None,
  133. dataset=dict(pipeline=train_pipeline))
  134. val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
  135. test_dataloader = val_dataloader
  136. max_epochs = 55
  137. # training schedule for 55e
  138. train_cfg = dict(
  139. type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
  140. val_cfg = dict(type='ValLoop')
  141. test_cfg = dict(type='TestLoop')
  142. # learning rate
  143. param_scheduler = [
  144. dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=500),
  145. dict(
  146. type='MultiStepLR',
  147. begin=0,
  148. end=max_epochs,
  149. by_epoch=True,
  150. milestones=[20, 42, 49, 52],
  151. gamma=0.1)
  152. ]
  153. # optimizer
  154. optim_wrapper = dict(
  155. type='OptimWrapper',
  156. optimizer=dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4))
  157. custom_hooks = [
  158. dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW')
  159. ]
  160. env_cfg = dict(cudnn_benchmark=True)
  161. # NOTE: `auto_scale_lr` is for automatically scaling LR,
  162. # USER SHOULD NOT CHANGE ITS VALUES.
  163. # base_batch_size = (1 GPUs) x (8 samples per GPU)
  164. auto_scale_lr = dict(base_batch_size=8)