# rtmpose-m_8xb256-210e_hand5-256x256.py

  1. _base_ = ['../../../_base_/default_runtime.py']
  2. # coco-hand onehand10k freihand2d rhd2d halpehand
  3. # runtime
  4. max_epochs = 210
  5. stage2_num_epochs = 10
  6. base_lr = 4e-3
  7. train_cfg = dict(max_epochs=max_epochs, val_interval=10)
  8. randomness = dict(seed=21)
  9. # optimizer
  10. optim_wrapper = dict(
  11. type='OptimWrapper',
  12. optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
  13. paramwise_cfg=dict(
  14. norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
  15. # learning rate
  16. param_scheduler = [
  17. dict(
  18. type='LinearLR',
  19. start_factor=1.0e-5,
  20. by_epoch=False,
  21. begin=0,
  22. end=1000),
  23. dict(
  24. # use cosine lr from 150 to 300 epoch
  25. type='CosineAnnealingLR',
  26. eta_min=base_lr * 0.05,
  27. begin=max_epochs // 2,
  28. end=max_epochs,
  29. T_max=max_epochs // 2,
  30. by_epoch=True,
  31. convert_to_iter_based=True),
  32. ]
  33. # automatically scaling LR based on the actual training batch size
  34. auto_scale_lr = dict(base_batch_size=256)
  35. # codec settings
  36. codec = dict(
  37. type='SimCCLabel',
  38. input_size=(256, 256),
  39. sigma=(5.66, 5.66),
  40. simcc_split_ratio=2.0,
  41. normalize=False,
  42. use_dark=False)
  43. # model settings
  44. model = dict(
  45. type='TopdownPoseEstimator',
  46. data_preprocessor=dict(
  47. type='PoseDataPreprocessor',
  48. mean=[123.675, 116.28, 103.53],
  49. std=[58.395, 57.12, 57.375],
  50. bgr_to_rgb=True),
  51. backbone=dict(
  52. _scope_='mmdet',
  53. type='CSPNeXt',
  54. arch='P5',
  55. expand_ratio=0.5,
  56. deepen_factor=0.67,
  57. widen_factor=0.75,
  58. out_indices=(4, ),
  59. channel_attention=True,
  60. norm_cfg=dict(type='SyncBN'),
  61. act_cfg=dict(type='SiLU'),
  62. init_cfg=dict(
  63. type='Pretrained',
  64. prefix='backbone.',
  65. checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
  66. 'rtmpose/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth' # noqa
  67. )),
  68. head=dict(
  69. type='RTMCCHead',
  70. in_channels=768,
  71. out_channels=21,
  72. input_size=codec['input_size'],
  73. in_featuremap_size=(8, 8),
  74. simcc_split_ratio=codec['simcc_split_ratio'],
  75. final_layer_kernel_size=7,
  76. gau_cfg=dict(
  77. hidden_dims=256,
  78. s=128,
  79. expansion_factor=2,
  80. dropout_rate=0.,
  81. drop_path=0.,
  82. act_fn='SiLU',
  83. use_rel_bias=False,
  84. pos_enc=False),
  85. loss=dict(
  86. type='KLDiscretLoss',
  87. use_target_weight=True,
  88. beta=10.,
  89. label_softmax=True),
  90. decoder=codec),
  91. test_cfg=dict(flip_test=True, ))
  92. # base dataset settings
  93. dataset_type = 'CocoWholeBodyHandDataset'
  94. data_mode = 'topdown'
  95. data_root = 'data/'
  96. backend_args = dict(backend='local')
  97. # pipelines
  98. train_pipeline = [
  99. dict(type='LoadImage', backend_args=backend_args),
  100. dict(type='GetBBoxCenterScale'),
  101. # dict(type='RandomHalfBody'),
  102. dict(
  103. type='RandomBBoxTransform', scale_factor=[0.5, 1.5],
  104. rotate_factor=180),
  105. dict(type='RandomFlip', direction='horizontal'),
  106. dict(type='TopdownAffine', input_size=codec['input_size']),
  107. dict(type='mmdet.YOLOXHSVRandomAug'),
  108. dict(
  109. type='Albumentation',
  110. transforms=[
  111. dict(type='Blur', p=0.1),
  112. dict(type='MedianBlur', p=0.1),
  113. dict(
  114. type='CoarseDropout',
  115. max_holes=1,
  116. max_height=0.4,
  117. max_width=0.4,
  118. min_holes=1,
  119. min_height=0.2,
  120. min_width=0.2,
  121. p=1.0),
  122. ]),
  123. dict(type='GenerateTarget', encoder=codec),
  124. dict(type='PackPoseInputs')
  125. ]
  126. val_pipeline = [
  127. dict(type='LoadImage', backend_args=backend_args),
  128. dict(type='GetBBoxCenterScale'),
  129. dict(type='TopdownAffine', input_size=codec['input_size']),
  130. dict(type='PackPoseInputs')
  131. ]
  132. train_pipeline_stage2 = [
  133. dict(type='LoadImage', backend_args=backend_args),
  134. dict(type='GetBBoxCenterScale'),
  135. # dict(type='RandomHalfBody'),
  136. dict(
  137. type='RandomBBoxTransform',
  138. shift_factor=0.,
  139. scale_factor=[0.75, 1.25],
  140. rotate_factor=180),
  141. dict(type='RandomFlip', direction='horizontal'),
  142. dict(type='TopdownAffine', input_size=codec['input_size']),
  143. dict(type='mmdet.YOLOXHSVRandomAug'),
  144. dict(
  145. type='Albumentation',
  146. transforms=[
  147. dict(type='Blur', p=0.2),
  148. dict(type='MedianBlur', p=0.2),
  149. dict(
  150. type='CoarseDropout',
  151. max_holes=1,
  152. max_height=0.4,
  153. max_width=0.4,
  154. min_holes=1,
  155. min_height=0.2,
  156. min_width=0.2,
  157. p=0.5),
  158. ]),
  159. dict(type='GenerateTarget', encoder=codec),
  160. dict(type='PackPoseInputs')
  161. ]
  162. # train datasets
  163. dataset_coco = dict(
  164. type=dataset_type,
  165. data_root=data_root,
  166. data_mode=data_mode,
  167. ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
  168. data_prefix=dict(img='detection/coco/train2017/'),
  169. pipeline=[],
  170. )
  171. dataset_onehand10k = dict(
  172. type='OneHand10KDataset',
  173. data_root=data_root,
  174. data_mode=data_mode,
  175. ann_file='onehand10k/annotations/onehand10k_train.json',
  176. data_prefix=dict(img='pose/OneHand10K/'),
  177. pipeline=[],
  178. )
  179. dataset_freihand = dict(
  180. type='FreiHandDataset',
  181. data_root=data_root,
  182. data_mode=data_mode,
  183. ann_file='freihand/annotations/freihand_train.json',
  184. data_prefix=dict(img='pose/FreiHand/'),
  185. pipeline=[],
  186. )
  187. dataset_rhd = dict(
  188. type='Rhd2DDataset',
  189. data_root=data_root,
  190. data_mode=data_mode,
  191. ann_file='rhd/annotations/rhd_train.json',
  192. data_prefix=dict(img='pose/RHD/'),
  193. pipeline=[
  194. dict(
  195. type='KeypointConverter',
  196. num_keypoints=21,
  197. mapping=[
  198. (0, 0),
  199. (1, 4),
  200. (2, 3),
  201. (3, 2),
  202. (4, 1),
  203. (5, 8),
  204. (6, 7),
  205. (7, 6),
  206. (8, 5),
  207. (9, 12),
  208. (10, 11),
  209. (11, 10),
  210. (12, 9),
  211. (13, 16),
  212. (14, 15),
  213. (15, 14),
  214. (16, 13),
  215. (17, 20),
  216. (18, 19),
  217. (19, 18),
  218. (20, 17),
  219. ])
  220. ],
  221. )
  222. dataset_halpehand = dict(
  223. type='HalpeHandDataset',
  224. data_root=data_root,
  225. data_mode=data_mode,
  226. ann_file='halpe/annotations/halpe_train_v1.json',
  227. data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015/'),
  228. pipeline=[],
  229. )
  230. # data loaders
  231. train_dataloader = dict(
  232. batch_size=256,
  233. num_workers=10,
  234. persistent_workers=True,
  235. sampler=dict(type='DefaultSampler', shuffle=True),
  236. dataset=dict(
  237. type='CombinedDataset',
  238. metainfo=dict(
  239. from_file='configs/_base_/datasets/coco_wholebody_hand.py'),
  240. datasets=[
  241. dataset_coco, dataset_onehand10k, dataset_freihand, dataset_rhd,
  242. dataset_halpehand
  243. ],
  244. pipeline=train_pipeline,
  245. test_mode=False,
  246. ))
  247. # test datasets
  248. val_coco = dict(
  249. type=dataset_type,
  250. data_root=data_root,
  251. data_mode=data_mode,
  252. ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
  253. data_prefix=dict(img='detection/coco/val2017/'),
  254. pipeline=[],
  255. )
  256. val_onehand10k = dict(
  257. type='OneHand10KDataset',
  258. data_root=data_root,
  259. data_mode=data_mode,
  260. ann_file='onehand10k/annotations/onehand10k_test.json',
  261. data_prefix=dict(img='pose/OneHand10K/'),
  262. pipeline=[],
  263. )
  264. val_freihand = dict(
  265. type='FreiHandDataset',
  266. data_root=data_root,
  267. data_mode=data_mode,
  268. ann_file='freihand/annotations/freihand_test.json',
  269. data_prefix=dict(img='pose/FreiHand/'),
  270. pipeline=[],
  271. )
  272. val_rhd = dict(
  273. type='Rhd2DDataset',
  274. data_root=data_root,
  275. data_mode=data_mode,
  276. ann_file='rhd/annotations/rhd_test.json',
  277. data_prefix=dict(img='pose/RHD/'),
  278. pipeline=[
  279. dict(
  280. type='KeypointConverter',
  281. num_keypoints=21,
  282. mapping=[
  283. (0, 0),
  284. (1, 4),
  285. (2, 3),
  286. (3, 2),
  287. (4, 1),
  288. (5, 8),
  289. (6, 7),
  290. (7, 6),
  291. (8, 5),
  292. (9, 12),
  293. (10, 11),
  294. (11, 10),
  295. (12, 9),
  296. (13, 16),
  297. (14, 15),
  298. (15, 14),
  299. (16, 13),
  300. (17, 20),
  301. (18, 19),
  302. (19, 18),
  303. (20, 17),
  304. ])
  305. ],
  306. )
  307. val_halpehand = dict(
  308. type='HalpeHandDataset',
  309. data_root=data_root,
  310. data_mode=data_mode,
  311. ann_file='halpe/annotations/halpe_val_v1.json',
  312. data_prefix=dict(img='detection/coco/val2017/'),
  313. pipeline=[],
  314. )
  315. test_dataloader = dict(
  316. batch_size=32,
  317. num_workers=10,
  318. persistent_workers=True,
  319. drop_last=False,
  320. sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
  321. dataset=dict(
  322. type='CombinedDataset',
  323. metainfo=dict(
  324. from_file='configs/_base_/datasets/coco_wholebody_hand.py'),
  325. datasets=[
  326. val_coco, val_onehand10k, val_freihand, val_rhd, val_halpehand
  327. ],
  328. pipeline=val_pipeline,
  329. test_mode=True,
  330. ))
  331. val_dataloader = test_dataloader
  332. # hooks
  333. default_hooks = dict(
  334. checkpoint=dict(save_best='AUC', rule='greater', max_keep_ckpts=1))
  335. custom_hooks = [
  336. dict(
  337. type='EMAHook',
  338. ema_type='ExpMomentumEMA',
  339. momentum=0.0002,
  340. update_buffers=True,
  341. priority=49),
  342. dict(
  343. type='mmdet.PipelineSwitchHook',
  344. switch_epoch=max_epochs - stage2_num_epochs,
  345. switch_pipeline=train_pipeline_stage2)
  346. ]
  347. # evaluators
  348. val_evaluator = [
  349. dict(type='PCKAccuracy', thr=0.2),
  350. dict(type='AUC'),
  351. dict(type='EPE')
  352. ]
  353. test_evaluator = val_evaluator