# yolox_s_8xb8-300e_coco.py

_base_ = [
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
    './yolox_tta.py'
]

img_scale = (640, 640)  # width, height

# model settings
model = dict(
    type='YOLOX',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        pad_size_divisor=32,
        batch_augments=[
            dict(
                type='BatchSyncRandomResize',
                random_size_range=(480, 800),
                size_divisor=32,
                interval=10)
        ]),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=0.33,
        widen_factor=0.5,
        out_indices=(2, 3, 4),
        use_depthwise=False,
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[128, 256, 512],
        out_channels=128,
        num_csp_blocks=1,
        use_depthwise=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish')),
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=80,
        in_channels=128,
        feat_channels=128,
        stacked_convs=2,
        strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # To align with the official implementation, the score threshold is 0.01
    # in the val phase and 0.001 in the test phase (see the sketch below).
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
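
# The comment above keeps score_thr at 0.01 for validation. A hedged sketch
# (assuming the standard mmdet 3.x `tools/test.py` entry point and its
# `--cfg-options` flag; the checkpoint path below is a placeholder) of
# lowering it to 0.001 for a final test run without editing this file:
#
#   python tools/test.py configs/yolox/yolox_s_8xb8-300e_coco.py \
#       work_dirs/yolox_s_8xb8-300e_coco/epoch_300.pth \
#       --cfg-options model.test_cfg.score_thr=0.001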

# dataset settings
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

# Example of using a different file client.
# Method 1: simply set the data root and let the file I/O module infer the
# backend from the path prefix (LMDB and Memcache are not supported yet).
# data_root = 's3://openmmlab/datasets/detection/coco/'
# Method 2: use `backend_args` (`file_client_args` in versions before 3.0.0rc6)
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
backend_args = None

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    # Following the official implementation, multi-scale training is not
    # handled here but in 'mmdet/models/detectors/yolox.py'.
    # Resize and Pad cover the last 15 epochs, when Mosaic, RandomAffine and
    # MixUp are closed by YOLOXModeSwitchHook (see the sketch after this list).
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]
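
# As noted above, Mosaic/RandomAffine/MixUp are closed by YOLOXModeSwitchHook
# for the final `num_last_epochs` epochs, so only the tail of the pipeline
# stays active. A rough, illustrative sketch of the effective per-image
# pipeline in that phase (not a key read by this config; the exact behaviour
# depends on how the hook skips transforms in your mmdet version):
#
# last_epochs_pipeline = [
#     dict(type='LoadImageFromFile', backend_args=backend_args),
#     dict(type='LoadAnnotations', with_bbox=True),
#     dict(type='YOLOXHSVRandomAug'),
#     dict(type='RandomFlip', prob=0.5),
#     dict(type='Resize', scale=img_scale, keep_ratio=True),
#     dict(
#         type='Pad',
#         pad_to_square=True,
#         pad_val=dict(img=(114.0, 114.0, 114.0))),
#     dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
#     dict(type='PackDetInputs')
# ]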

train_dataset = dict(
    # use MultiImageMixDataset wrapper to support mosaic and mixup
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[
            dict(type='LoadImageFromFile', backend_args=backend_args),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        backend_args=backend_args),
    pipeline=train_pipeline)
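
# The inner pipeline only loads images and boxes; Mosaic and MixUp live in the
# wrapper's `pipeline` because they need to pull extra images from the wrapped
# dataset. A hedged sketch of building the dataset outside the runner, e.g.
# for a quick sanity check (assumes the mmdet 3.x registry imports below):
#
#   from mmdet.registry import DATASETS
#   from mmdet.utils import register_all_modules
#   register_all_modules()
#   dataset = DATASETS.build(train_dataset)
#   sample = dataset[0]  # dict with 'inputs' and 'data_samples'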

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
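
# Test-time augmentation is inherited from './yolox_tta.py' via `_base_` but
# is not active by default. A hedged sketch of switching it on, assuming the
# mmdet 3.x `tools/test.py` script and its `--tta` flag (the checkpoint path
# is a placeholder):
#
#   python tools/test.py configs/yolox/yolox_s_8xb8-300e_coco.py \
#       work_dirs/yolox_s_8xb8-300e_coco/epoch_300.pth --tta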

train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
val_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    backend_args=backend_args)
test_evaluator = val_evaluator
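
# A hedged sketch of launching training and evaluation with this config,
# assuming the standard mmdet repo layout (tools/dist_train.sh, tools/test.py;
# paths are placeholders). The file name encodes the intended setup: 8 GPUs x
# 8 images per GPU for 300 epochs on COCO.
#
#   bash tools/dist_train.sh configs/yolox/yolox_s_8xb8-300e_coco.py 8
#   python tools/test.py configs/yolox/yolox_s_8xb8-300e_coco.py \
#       work_dirs/yolox_s_8xb8-300e_coco/epoch_300.pth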

# training settings
max_epochs = 300
num_last_epochs = 15
interval = 10

train_cfg = dict(max_epochs=max_epochs, val_interval=interval)

# optimizer
# default: 8 GPUs
base_lr = 0.01
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
        nesterov=True),
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
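
# Descriptive note (hedged): `paramwise_cfg` sets weight decay to zero for
# normalization layers and biases, matching the official YOLOX recipe, so only
# conv/linear weights are regularized with weight_decay=5e-4.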

# learning rate
param_scheduler = [
    dict(
        # use a quadratic formula to warm up for 5 epochs,
        # with the lr updated per iteration
        # TODO: fix default scope in get function
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        # use cosine lr from epoch 5 to 285
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        # use a fixed lr during the last 15 epochs
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    )
]
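
# Putting the three schedulers together (values derived from this file):
#   epochs   0-5  : quadratic warmup up to base_lr = 0.01, stepped per iteration
#   epochs   5-285: cosine annealing from 0.01 down to eta_min = 0.01 * 0.05 = 5e-4
#   epochs 285-300: constant lr (factor=1, i.e. it stays at the cosine's final 5e-4)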

default_hooks = dict(
    checkpoint=dict(
        interval=interval,
        max_keep_ckpts=3  # only keep latest 3 checkpoints
    ))

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49)
]
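
# Hook notes (hedged, descriptive only):
# - YOLOXModeSwitchHook closes Mosaic/RandomAffine/MixUp for the last
#   `num_last_epochs` epochs, matching the pipeline comment above; in upstream
#   mmdet it also switches the head to the L1 loss defined in bbox_head
#   (verify against your mmdet version).
# - Priorities 48/49 place these hooks ahead of the default NORMAL(50) hooks,
#   so the mode switch, norm sync and EMA update run before e.g. logging and
#   checkpointing at each hook point.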

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
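
# A worked example of the linear scaling this enables (hedged: it only takes
# effect when auto LR scaling is switched on at launch, e.g. via the
# `--auto-scale-lr` flag of mmdet's tools/train.py):
#
#   effective_lr = base_lr * (num_gpus * samples_per_gpu) / base_batch_size
#   # e.g. 4 GPUs x 8 samples per GPU: 0.01 * 32 / 64 = 0.005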