# mask2former_swin-t-p4-w7-224_8xb2-lsj-50e_coco-panoptic.py (1.9 KB)
  1. _base_ = ['./mask2former_r50_8xb2-lsj-50e_coco-panoptic.py']
  2. pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa
  3. depths = [2, 2, 6, 2]
  4. model = dict(
  5. type='Mask2Former',
  6. backbone=dict(
  7. _delete_=True,
  8. type='SwinTransformer',
  9. embed_dims=96,
  10. depths=depths,
  11. num_heads=[3, 6, 12, 24],
  12. window_size=7,
  13. mlp_ratio=4,
  14. qkv_bias=True,
  15. qk_scale=None,
  16. drop_rate=0.,
  17. attn_drop_rate=0.,
  18. drop_path_rate=0.3,
  19. patch_norm=True,
  20. out_indices=(0, 1, 2, 3),
  21. with_cp=False,
  22. convert_weights=True,
  23. frozen_stages=-1,
  24. init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
  25. panoptic_head=dict(
  26. type='Mask2FormerHead', in_channels=[96, 192, 384, 768]),
  27. init_cfg=None)
  28. # set all layers in backbone to lr_mult=0.1
  29. # set all norm layers, position_embeding,
  30. # query_embeding, level_embeding to decay_multi=0.0
  31. backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
  32. backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
  33. embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
  34. custom_keys = {
  35. 'backbone': dict(lr_mult=0.1, decay_mult=1.0),
  36. 'backbone.patch_embed.norm': backbone_norm_multi,
  37. 'backbone.norm': backbone_norm_multi,
  38. 'absolute_pos_embed': backbone_embed_multi,
  39. 'relative_position_bias_table': backbone_embed_multi,
  40. 'query_embed': embed_multi,
  41. 'query_feat': embed_multi,
  42. 'level_embed': embed_multi
  43. }
  44. custom_keys.update({
  45. f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
  46. for stage_id, num_blocks in enumerate(depths)
  47. for block_id in range(num_blocks)
  48. })
  49. custom_keys.update({
  50. f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
  51. for stage_id in range(len(depths) - 1)
  52. })
  53. # optimizer
  54. optim_wrapper = dict(
  55. paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))