123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
# Inherit shared runtime defaults from the base config; the settings below
# override or extend it.
_base_ = ['../../../_base_/default_runtime.py']

# ---------------------------------------------------------------------------
# runtime
# ---------------------------------------------------------------------------
max_epochs = 420  # total training length, in epochs
# Number of final epochs trained with the milder stage-2 pipeline
# (see PipelineSwitchHook, which switches at max_epochs - stage2_num_epochs).
stage2_num_epochs = 30
base_lr = 4e-3  # base AdamW learning rate

# Validate every 10 epochs.
train_cfg = {'max_epochs': max_epochs, 'val_interval': 10}
randomness = {'seed': 21}
# optimizer: AdamW; normalization-layer parameters and biases are excluded
# from weight decay (both decay multipliers are 0).
optim_wrapper = {
    'type': 'OptimWrapper',
    'optimizer': {
        'type': 'AdamW',
        'lr': base_lr,
        'weight_decay': 0.05,
    },
    'paramwise_cfg': {
        'norm_decay_mult': 0,
        'bias_decay_mult': 0,
        'bypass_duplicate': True,
    },
}
# learning rate schedule:
#   1. iterations 0-1000: linear warm-up starting from 1e-5 x base_lr;
#   2. epochs max_epochs // 2 -> max_epochs (the second half of training):
#      cosine decay towards 0.05 x base_lr, stepped per iteration because
#      convert_to_iter_based=True.
param_scheduler = [
    {
        'type': 'LinearLR',
        'start_factor': 1.0e-5,
        'by_epoch': False,
        'begin': 0,
        'end': 1000,
    },
    {
        'type': 'CosineAnnealingLR',
        'eta_min': base_lr * 0.05,
        'begin': max_epochs // 2,
        'end': max_epochs,
        'T_max': max_epochs // 2,
        'by_epoch': True,
        'convert_to_iter_based': True,
    },
]
# Reference total batch size for automatic LR scaling: the LR is scaled by
# the ratio of the actual training batch size to this value (presumably only
# when auto-scaling is enabled at launch time -- confirm).
auto_scale_lr = {'base_batch_size': 1024}
# codec: SimCC label encoder/decoder. The same dict drives the head
# (input_size, simcc_split_ratio, decoder), TopdownAffine and GenerateTarget.
codec = {
    'type': 'SimCCLabel',
    'input_size': (288, 384),  # model input resolution
    # Gaussian label widths, one per axis (presumably x then y -- confirm).
    'sigma': (6., 6.93),
    # Ratio of SimCC bins to input pixels along each axis.
    'simcc_split_ratio': 2.0,
    'normalize': False,
    'use_dark': False,
}
# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        # Per-channel normalization on the 0-255 scale, applied together with
        # a BGR->RGB channel swap (the values look like the usual ImageNet
        # statistics).
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        # CSPNeXt is looked up in the mmdet registry.
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1.,
        widen_factor=1.,
        # A single feature map (index 4) is passed on to the head.
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        # Initialize from a pretrained CSPNeXt-L checkpoint; only weights
        # under the 'backbone.' prefix are loaded into this module.
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmposev1/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'  # noqa
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=1024,
        # One output channel per keypoint of the 17-keypoint COCO layout.
        out_channels=17,
        input_size=codec['input_size'],
        # Spatial size of the backbone feature map entering the head. It is
        # derived from the codec input size at output stride 32 instead of
        # being hard-coded, so it cannot drift out of sync when the input
        # resolution changes: (288, 384) -> (9, 12).
        in_featuremap_size=tuple([s // 32 for s in codec['input_size']]),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        # Predictions are decoded with the same codec used to build targets.
        decoder=codec),
    # Enable flip-test augmentation at inference.
    test_cfg=dict(flip_test=True, ))
# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/'

# Images are read from the local filesystem. To read them through the petrel
# backend instead, use something like:
#   backend_args = dict(
#       backend='petrel',
#       path_mapping=dict({f'{data_root}': 's3://openmmlab/datasets/'}))
backend_args = {'backend': 'local'}
# pipelines
# Stage-1 training pipeline (used until PipelineSwitchHook swaps in
# train_pipeline_stage2): flip, half-body, strong scale/rotation jitter,
# HSV jitter, blur, and a cutout applied to every sample.
train_pipeline = [
    {'type': 'LoadImage', 'backend_args': backend_args},
    {'type': 'GetBBoxCenterScale'},
    {'type': 'RandomFlip', 'direction': 'horizontal'},
    {'type': 'RandomHalfBody'},
    {
        'type': 'RandomBBoxTransform',
        'scale_factor': [0.6, 1.4],
        'rotate_factor': 80,
    },
    {'type': 'TopdownAffine', 'input_size': codec['input_size']},
    {'type': 'mmdet.YOLOXHSVRandomAug'},
    {
        'type': 'Albumentation',
        'transforms': [
            {'type': 'Blur', 'p': 0.1},
            {'type': 'MedianBlur', 'p': 0.1},
            # Exactly one hole (min_holes == max_holes == 1), always applied.
            {
                'type': 'CoarseDropout',
                'max_holes': 1,
                'max_height': 0.4,
                'max_width': 0.4,
                'min_holes': 1,
                'min_height': 0.2,
                'min_width': 0.2,
                'p': 1.0,
            },
        ],
    },
    {'type': 'GenerateTarget', 'encoder': codec},
    {'type': 'PackPoseInputs'},
]
# Evaluation pipeline: no augmentation -- load, warp the bbox region to the
# model input size, and pack.
val_pipeline = [
    {'type': 'LoadImage', 'backend_args': backend_args},
    {'type': 'GetBBoxCenterScale'},
    {'type': 'TopdownAffine', 'input_size': codec['input_size']},
    {'type': 'PackPoseInputs'},
]
# Stage-2 training pipeline (the final stage2_num_epochs epochs): the same
# transforms as stage 1 but milder -- no bbox shift, a narrower scale range,
# smaller rotations, and the cutout applied to only half of the samples.
train_pipeline_stage2 = [
    {'type': 'LoadImage', 'backend_args': backend_args},
    {'type': 'GetBBoxCenterScale'},
    {'type': 'RandomFlip', 'direction': 'horizontal'},
    {'type': 'RandomHalfBody'},
    {
        'type': 'RandomBBoxTransform',
        'shift_factor': 0.,
        'scale_factor': [0.75, 1.25],
        'rotate_factor': 60,
    },
    {'type': 'TopdownAffine', 'input_size': codec['input_size']},
    {'type': 'mmdet.YOLOXHSVRandomAug'},
    {
        'type': 'Albumentation',
        'transforms': [
            {'type': 'Blur', 'p': 0.1},
            {'type': 'MedianBlur', 'p': 0.1},
            {
                'type': 'CoarseDropout',
                'max_holes': 1,
                'max_height': 0.4,
                'max_width': 0.4,
                'min_holes': 1,
                'min_height': 0.2,
                'min_width': 0.2,
                'p': 0.5,
            },
        ],
    },
    {'type': 'GenerateTarget', 'encoder': codec},
    {'type': 'PackPoseInputs'},
]
# train datasets
# COCO train2017, repeated 3 times. Its own pipeline is empty; the shared
# train_pipeline is applied by the CombinedDataset in train_dataloader.
dataset_coco = {
    'type': 'RepeatDataset',
    'dataset': {
        'type': dataset_type,
        'data_root': data_root,
        'data_mode': data_mode,
        'ann_file': 'coco/annotations/person_keypoints_train2017.json',
        'data_prefix': {'img': 'detection/coco/train2017/'},
        'pipeline': [],
    },
    'times': 3,
}
# AI Challenger keypoint training set, converted to the 17-keypoint layout
# used by the COCO metainfo of the combined dataset.
dataset_aic = {
    'type': 'AicDataset',
    'data_root': data_root,
    'data_mode': data_mode,
    'ann_file': 'aic/annotations/aic_train.json',
    'data_prefix': {
        'img': ('pose/ai_challenge/ai_challenger_keypoint'
                '_train_20170902/keypoint_train_images_20170902/'),
    },
    'pipeline': [
        # mapping pairs are presumably (source AIC index, target index) per
        # KeypointConverter's convention; source indices not listed are
        # not carried over.
        {
            'type': 'KeypointConverter',
            'num_keypoints': 17,
            'mapping': [
                (0, 6), (1, 8), (2, 10),
                (3, 5), (4, 7), (5, 9),
                (6, 12), (7, 14), (8, 16),
                (9, 11), (10, 13), (11, 15),
            ],
        },
    ],
}
# data loaders
# Training: COCO (repeated) and AIC combined under the COCO metainfo, with
# every sample going through the shared train_pipeline.
train_dataloader = {
    'batch_size': 256,
    'num_workers': 10,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': True},
    'dataset': {
        'type': 'CombinedDataset',
        'metainfo': {'from_file': 'configs/_base_/datasets/coco.py'},
        'datasets': [dataset_coco, dataset_aic],
        'pipeline': train_pipeline,
        'test_mode': False,
    },
}
# Validation on COCO val2017, in order, without padding or dropping samples.
val_dataloader = {
    'batch_size': 64,
    'num_workers': 10,
    'persistent_workers': True,
    'drop_last': False,
    'sampler': {
        'type': 'DefaultSampler',
        'shuffle': False,
        'round_up': False,
    },
    'dataset': {
        'type': dataset_type,
        'data_root': data_root,
        'data_mode': data_mode,
        'ann_file': 'coco/annotations/person_keypoints_val2017.json',
        # Optional person-detection results file (disabled):
        # 'bbox_file': 'data/coco/person_detection_results/'
        # 'COCO_val2017_detections_AP_H_56_person.json',
        'data_prefix': {'img': 'detection/coco/val2017/'},
        'test_mode': True,
        'pipeline': val_pipeline,
    },
}
# Testing reuses the validation loader as-is.
test_dataloader = val_dataloader
# hooks
# Save the best checkpoint by COCO AP (higher is better) and cap retained
# checkpoints with max_keep_ckpts=1.
default_hooks = {
    'checkpoint': {
        'save_best': 'coco/AP',
        'rule': 'greater',
        'max_keep_ckpts': 1,
    },
}
custom_hooks = [
    # Exponential-momentum EMA of the model, buffers included.
    {
        'type': 'EMAHook',
        'ema_type': 'ExpMomentumEMA',
        'momentum': 0.0002,
        'update_buffers': True,
        'priority': 49,
    },
    # Switch to the milder stage-2 augmentation for the last
    # stage2_num_epochs epochs.
    {
        'type': 'mmdet.PipelineSwitchHook',
        'switch_epoch': max_epochs - stage2_num_epochs,
        'switch_pipeline': train_pipeline_stage2,
    },
]
# evaluators
# COCO keypoint metric against the val2017 annotations; testing reuses it.
val_evaluator = {
    'type': 'CocoMetric',
    'ann_file': data_root + 'coco/annotations/person_keypoints_val2017.json',
}
test_evaluator = val_evaluator
|