# rtmpose-m_8xb64-60e_wflw-256x256.py
# RTMPose-m config for WFLW face landmarks, 256x256 input, 60 epochs.
  1. _base_ = ['../../../_base_/default_runtime.py']
  2. # runtime
  3. max_epochs = 60
  4. stage2_num_epochs = 10
  5. base_lr = 4e-3
  6. train_cfg = dict(max_epochs=max_epochs, val_interval=1)
  7. randomness = dict(seed=21)
  8. # optimizer
  9. optim_wrapper = dict(
  10. type='OptimWrapper',
  11. optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
  12. paramwise_cfg=dict(
  13. norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
  14. # learning rate
  15. param_scheduler = [
  16. dict(
  17. type='LinearLR',
  18. start_factor=1.0e-5,
  19. by_epoch=False,
  20. begin=0,
  21. end=1000),
  22. dict(
  23. # use cosine lr from 150 to 300 epoch
  24. type='CosineAnnealingLR',
  25. eta_min=base_lr * 0.05,
  26. begin=max_epochs // 2,
  27. end=max_epochs,
  28. T_max=max_epochs // 2,
  29. by_epoch=True,
  30. convert_to_iter_based=True),
  31. ]
  32. # automatically scaling LR based on the actual training batch size
  33. auto_scale_lr = dict(base_batch_size=512)
  34. # codec settings
  35. codec = dict(
  36. type='SimCCLabel',
  37. input_size=(256, 256),
  38. sigma=(5.66, 5.66),
  39. simcc_split_ratio=2.0,
  40. normalize=False,
  41. use_dark=False)
  42. # model settings
  43. model = dict(
  44. type='TopdownPoseEstimator',
  45. data_preprocessor=dict(
  46. type='PoseDataPreprocessor',
  47. mean=[123.675, 116.28, 103.53],
  48. std=[58.395, 57.12, 57.375],
  49. bgr_to_rgb=True),
  50. backbone=dict(
  51. _scope_='mmdet',
  52. type='CSPNeXt',
  53. arch='P5',
  54. expand_ratio=0.5,
  55. deepen_factor=0.67,
  56. widen_factor=0.75,
  57. out_indices=(4, ),
  58. channel_attention=True,
  59. norm_cfg=dict(type='SyncBN'),
  60. act_cfg=dict(type='SiLU'),
  61. init_cfg=dict(
  62. type='Pretrained',
  63. prefix='backbone.',
  64. checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
  65. 'rtmposev1/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth' # noqa
  66. )),
  67. head=dict(
  68. type='RTMCCHead',
  69. in_channels=768,
  70. out_channels=98,
  71. input_size=codec['input_size'],
  72. in_featuremap_size=(8, 8),
  73. simcc_split_ratio=codec['simcc_split_ratio'],
  74. final_layer_kernel_size=7,
  75. gau_cfg=dict(
  76. hidden_dims=256,
  77. s=128,
  78. expansion_factor=2,
  79. dropout_rate=0.,
  80. drop_path=0.,
  81. act_fn='SiLU',
  82. use_rel_bias=False,
  83. pos_enc=False),
  84. loss=dict(
  85. type='KLDiscretLoss',
  86. use_target_weight=True,
  87. beta=10.,
  88. label_softmax=True),
  89. decoder=codec),
  90. test_cfg=dict(flip_test=True, ))
  91. # base dataset settings
  92. dataset_type = 'WFLWDataset'
  93. data_mode = 'topdown'
  94. data_root = 'data/wflw/'
  95. backend_args = dict(backend='local')
  96. # backend_args = dict(
  97. # backend='petrel',
  98. # path_mapping=dict({
  99. # f'{data_root}': 's3://openmmlab/datasets/pose/WFLW/',
  100. # f'{data_root}': 's3://openmmlab/datasets/pose/WFLW/'
  101. # }))
  102. # pipelines
  103. train_pipeline = [
  104. dict(type='LoadImage', backend_args=backend_args),
  105. dict(type='GetBBoxCenterScale'),
  106. dict(type='RandomFlip', direction='horizontal'),
  107. # dict(type='RandomHalfBody'),
  108. dict(
  109. type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
  110. dict(type='TopdownAffine', input_size=codec['input_size']),
  111. dict(type='mmdet.YOLOXHSVRandomAug'),
  112. dict(
  113. type='Albumentation',
  114. transforms=[
  115. dict(type='Blur', p=0.1),
  116. dict(type='MedianBlur', p=0.1),
  117. dict(
  118. type='CoarseDropout',
  119. max_holes=1,
  120. max_height=0.4,
  121. max_width=0.4,
  122. min_holes=1,
  123. min_height=0.2,
  124. min_width=0.2,
  125. p=1.0),
  126. ]),
  127. dict(type='GenerateTarget', encoder=codec),
  128. dict(type='PackPoseInputs')
  129. ]
  130. val_pipeline = [
  131. dict(type='LoadImage', backend_args=backend_args),
  132. dict(type='GetBBoxCenterScale'),
  133. dict(type='TopdownAffine', input_size=codec['input_size']),
  134. dict(type='PackPoseInputs')
  135. ]
  136. train_pipeline_stage2 = [
  137. dict(type='LoadImage', backend_args=backend_args),
  138. dict(type='GetBBoxCenterScale'),
  139. dict(type='RandomFlip', direction='horizontal'),
  140. # dict(type='RandomHalfBody'),
  141. dict(
  142. type='RandomBBoxTransform',
  143. shift_factor=0.,
  144. scale_factor=[0.75, 1.25],
  145. rotate_factor=60),
  146. dict(type='TopdownAffine', input_size=codec['input_size']),
  147. dict(type='mmdet.YOLOXHSVRandomAug'),
  148. dict(
  149. type='Albumentation',
  150. transforms=[
  151. dict(type='Blur', p=0.1),
  152. dict(type='MedianBlur', p=0.1),
  153. dict(
  154. type='CoarseDropout',
  155. max_holes=1,
  156. max_height=0.4,
  157. max_width=0.4,
  158. min_holes=1,
  159. min_height=0.2,
  160. min_width=0.2,
  161. p=0.5),
  162. ]),
  163. dict(type='GenerateTarget', encoder=codec),
  164. dict(type='PackPoseInputs')
  165. ]
  166. # data loaders
  167. train_dataloader = dict(
  168. batch_size=64,
  169. num_workers=10,
  170. persistent_workers=True,
  171. sampler=dict(type='DefaultSampler', shuffle=True),
  172. dataset=dict(
  173. type=dataset_type,
  174. data_root=data_root,
  175. data_mode=data_mode,
  176. ann_file='annotations/face_landmarks_wflw_train.json',
  177. data_prefix=dict(img='images/'),
  178. pipeline=train_pipeline,
  179. ))
  180. val_dataloader = dict(
  181. batch_size=32,
  182. num_workers=10,
  183. persistent_workers=True,
  184. drop_last=False,
  185. sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
  186. dataset=dict(
  187. type=dataset_type,
  188. data_root=data_root,
  189. data_mode=data_mode,
  190. ann_file='annotations/face_landmarks_wflw_test.json',
  191. data_prefix=dict(img='images/'),
  192. test_mode=True,
  193. pipeline=val_pipeline,
  194. ))
  195. test_dataloader = val_dataloader
  196. # hooks
  197. default_hooks = dict(
  198. checkpoint=dict(
  199. save_best='NME', rule='less', max_keep_ckpts=1, interval=1))
  200. custom_hooks = [
  201. dict(
  202. type='EMAHook',
  203. ema_type='ExpMomentumEMA',
  204. momentum=0.0002,
  205. update_buffers=True,
  206. priority=49),
  207. dict(
  208. type='mmdet.PipelineSwitchHook',
  209. switch_epoch=max_epochs - stage2_num_epochs,
  210. switch_pipeline=train_pipeline_stage2)
  211. ]
  212. # evaluators
  213. val_evaluator = dict(
  214. type='NME',
  215. norm_mode='keypoint_distance',
  216. )
  217. test_evaluator = val_evaluator