- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- from collections import OrderedDict
- from copy import deepcopy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as cp
- from mmcv.cnn import build_norm_layer
- from mmcv.cnn.bricks.transformer import FFN, build_dropout
- from mmengine.logging import MMLogger
- from mmengine.model import BaseModule, ModuleList
- from mmengine.model.weight_init import (constant_init, trunc_normal_,
- trunc_normal_init)
- from mmengine.runner.checkpoint import CheckpointLoader
- from mmengine.utils import to_2tuple
- from mmdet.registry import MODELS
- from ..layers import PatchEmbed, PatchMerging
- class WindowMSA(BaseModule):
- """Window based multi-head self-attention (W-MSA) module with relative
- position bias.
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (tuple[int]): The height and width of the window.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Default: 0.0
- proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
- init_cfg (dict | None, optional): The Config for initialization.
- Default: None.
- """
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0.,
- proj_drop_rate=0.,
- init_cfg=None):
- super().__init__()
- self.embed_dims = embed_dims
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_embed_dims = embed_dims // num_heads
- self.scale = qk_scale or head_embed_dims**-0.5
- self.init_cfg = init_cfg
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
- num_heads)) # 2*Wh-1 * 2*Ww-1, nH
- # About 2x faster than original impl
- Wh, Ww = self.window_size
- rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
- rel_position_index = rel_index_coords + rel_index_coords.T
- rel_position_index = rel_position_index.flip(1).contiguous()
- self.register_buffer('relative_position_index', rel_position_index)
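- # Note: this index construction is intended to reproduce the relative
- # position index of the original meshgrid-based implementation while
- # avoiding building a full coordinate grid (see the "2x faster" note below).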
- self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop_rate)
- self.proj = nn.Linear(embed_dims, embed_dims)
- self.proj_drop = nn.Dropout(proj_drop_rate)
- self.softmax = nn.Softmax(dim=-1)
- def init_weights(self):
- trunc_normal_(self.relative_position_bias_table, std=0.02)
- def forward(self, x, mask=None):
- """
- Args:
- x (tensor): input features with shape of (num_windows*B, N, C)
- mask (tensor | None, optional): mask with shape (num_windows,
- Wh*Ww, Wh*Ww); values should lie in (-inf, 0].
- """
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
- # make torchscript happy (cannot use tensor as tuple)
- q, k, v = qkv[0], qkv[1], qkv[2]
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B // nW, nW, self.num_heads, N,
- N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- attn = self.attn_drop(attn)
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
- @staticmethod
- def double_step_seq(step1, len1, step2, len2):
- seq1 = torch.arange(0, step1 * len1, step1)
- seq2 = torch.arange(0, step2 * len2, step2)
- return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
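- # Usage sketch (illustrative values, not part of the original module):
- # WindowMSA attends within each pre-partitioned window independently.
- #     >>> attn = WindowMSA(embed_dims=96, num_heads=3, window_size=(7, 7))
- #     >>> x = torch.rand(8, 49, 96)  # (num_windows*B, Wh*Ww, C)
- #     >>> attn(x).shape
- #     torch.Size([8, 49, 96])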
- class ShiftWindowMSA(BaseModule):
- """Shifted Window Multihead Self-Attention Module.
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (int): The height and width of the window.
- shift_size (int, optional): The shift step of each window towards the
- bottom-right. If zero, acts as a regular window MSA. Default: 0.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Default: 0.
- proj_drop_rate (float, optional): Dropout ratio of output.
- Default: 0.
- dropout_layer (dict, optional): The dropout_layer used before output.
- Default: dict(type='DropPath', drop_prob=0.).
- init_cfg (dict, optional): The extra config for initialization.
- Default: None.
- """
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- shift_size=0,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0,
- proj_drop_rate=0,
- dropout_layer=dict(type='DropPath', drop_prob=0.),
- init_cfg=None):
- super().__init__(init_cfg)
- self.window_size = window_size
- self.shift_size = shift_size
- assert 0 <= self.shift_size < self.window_size
- self.w_msa = WindowMSA(
- embed_dims=embed_dims,
- num_heads=num_heads,
- window_size=to_2tuple(window_size),
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop_rate=attn_drop_rate,
- proj_drop_rate=proj_drop_rate,
- init_cfg=None)
- self.drop = build_dropout(dropout_layer)
- def forward(self, query, hw_shape):
- B, L, C = query.shape
- H, W = hw_shape
- assert L == H * W, 'input feature has wrong size'
- query = query.view(B, H, W, C)
- # pad feature maps to multiples of window size
- pad_r = (self.window_size - W % self.window_size) % self.window_size
- pad_b = (self.window_size - H % self.window_size) % self.window_size
- query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
- H_pad, W_pad = query.shape[1], query.shape[2]
- # cyclic shift
- if self.shift_size > 0:
- shifted_query = torch.roll(
- query,
- shifts=(-self.shift_size, -self.shift_size),
- dims=(1, 2))
- # calculate attention mask for SW-MSA
- img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
- # nW, window_size, window_size, 1
- mask_windows = self.window_partition(img_mask)
- mask_windows = mask_windows.view(
- -1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0,
- float(-100.0)).masked_fill(
- attn_mask == 0, float(0.0))
- else:
- shifted_query = query
- attn_mask = None
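- # When shift_size > 0, attn_mask is 0 for token pairs that came from the
- # same pre-shift region and -100 for pairs from different regions, so the
- # softmax in WindowMSA suppresses attention across region boundaries.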
- # nW*B, window_size, window_size, C
- query_windows = self.window_partition(shifted_query)
- # nW*B, window_size*window_size, C
- query_windows = query_windows.view(-1, self.window_size**2, C)
- # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
- attn_windows = self.w_msa(query_windows, mask=attn_mask)
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size,
- self.window_size, C)
- # B H' W' C
- shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x,
- shifts=(self.shift_size, self.shift_size),
- dims=(1, 2))
- else:
- x = shifted_x
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
- x = x.view(B, H * W, C)
- x = self.drop(x)
- return x
- def window_reverse(self, windows, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- window_size = self.window_size
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size,
- window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
- def window_partition(self, x):
- """
- Args:
- x: (B, H, W, C)
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- window_size = self.window_size
- x = x.view(B, H // window_size, window_size, W // window_size,
- window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
- windows = windows.view(-1, window_size, window_size, C)
- return windows
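- # Usage sketch (illustrative shapes): shifted-window attention over a
- # 14x14 feature map with a 7x7 window, shifted by window_size // 2 = 3.
- #     >>> sw_msa = ShiftWindowMSA(embed_dims=96, num_heads=3,
- #     ...                         window_size=7, shift_size=3)
- #     >>> x = torch.rand(2, 14 * 14, 96)  # (B, H*W, C)
- #     >>> sw_msa(x, (14, 14)).shape
- #     torch.Size([2, 196, 96])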
- class SwinBlock(BaseModule):
- """"
- Args:
- embed_dims (int): The feature dimension.
- num_heads (int): Parallel attention heads.
- feedforward_channels (int): The hidden dimension for FFNs.
- window_size (int, optional): The local window scale. Default: 7.
- shift (bool, optional): whether to shift window or not. Default False.
- qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- drop_rate (float, optional): Dropout rate. Default: 0.
- attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
- drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
- act_cfg (dict, optional): The config dict of activation function.
- Default: dict(type='GELU').
- norm_cfg (dict, optional): The config dict of normalization.
- Default: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- init_cfg (dict | list | None, optional): The init config.
- Default: None.
- """
- def __init__(self,
- embed_dims,
- num_heads,
- feedforward_channels,
- window_size=7,
- shift=False,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False,
- init_cfg=None):
- super(SwinBlock, self).__init__()
- self.init_cfg = init_cfg
- self.with_cp = with_cp
- self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
- self.attn = ShiftWindowMSA(
- embed_dims=embed_dims,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=window_size // 2 if shift else 0,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop_rate=attn_drop_rate,
- proj_drop_rate=drop_rate,
- dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
- init_cfg=None)
- self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
- self.ffn = FFN(
- embed_dims=embed_dims,
- feedforward_channels=feedforward_channels,
- num_fcs=2,
- ffn_drop=drop_rate,
- dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
- act_cfg=act_cfg,
- add_identity=True,
- init_cfg=None)
- def forward(self, x, hw_shape):
- def _inner_forward(x):
- identity = x
- x = self.norm1(x)
- x = self.attn(x, hw_shape)
- x = x + identity
- identity = x
- x = self.norm2(x)
- x = self.ffn(x, identity=identity)
- return x
- if self.with_cp and x.requires_grad:
- x = cp.checkpoint(_inner_forward, x)
- else:
- x = _inner_forward(x)
- return x
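- # Usage sketch (illustrative shapes): a single non-shifted block on a
- # 56x56 feature map.
- #     >>> block = SwinBlock(embed_dims=96, num_heads=3,
- #     ...                   feedforward_channels=384)
- #     >>> x = torch.rand(1, 56 * 56, 96)
- #     >>> block(x, (56, 56)).shape
- #     torch.Size([1, 3136, 96])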
- class SwinBlockSequence(BaseModule):
- """Implements one stage in Swin Transformer.
- Args:
- embed_dims (int): The feature dimension.
- num_heads (int): Parallel attention heads.
- feedforward_channels (int): The hidden dimension for FFNs.
- depth (int): The number of blocks in this stage.
- window_size (int, optional): The local window scale. Default: 7.
- qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- drop_rate (float, optional): Dropout rate. Default: 0.
- attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
- drop_path_rate (float | list[float], optional): Stochastic depth
- rate. Default: 0.
- downsample (BaseModule | None, optional): The downsample operation
- module. Default: None.
- act_cfg (dict, optional): The config dict of activation function.
- Default: dict(type='GELU').
- norm_cfg (dict, optional): The config dict of normalization.
- Default: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- init_cfg (dict | list | None, optional): The init config.
- Default: None.
- """
- def __init__(self,
- embed_dims,
- num_heads,
- feedforward_channels,
- depth,
- window_size=7,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- downsample=None,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False,
- init_cfg=None):
- super().__init__(init_cfg=init_cfg)
- if isinstance(drop_path_rate, list):
- drop_path_rates = drop_path_rate
- assert len(drop_path_rates) == depth
- else:
- drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
- self.blocks = ModuleList()
- for i in range(depth):
- block = SwinBlock(
- embed_dims=embed_dims,
- num_heads=num_heads,
- feedforward_channels=feedforward_channels,
- window_size=window_size,
- shift=False if i % 2 == 0 else True,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop_rate=drop_rate,
- attn_drop_rate=attn_drop_rate,
- drop_path_rate=drop_path_rates[i],
- act_cfg=act_cfg,
- norm_cfg=norm_cfg,
- with_cp=with_cp,
- init_cfg=None)
- self.blocks.append(block)
- self.downsample = downsample
- def forward(self, x, hw_shape):
- for block in self.blocks:
- x = block(x, hw_shape)
- if self.downsample:
- x_down, down_hw_shape = self.downsample(x, hw_shape)
- return x_down, down_hw_shape, x, hw_shape
- else:
- return x, hw_shape, x, hw_shape
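- # Note: a stage returns both the downsampled tensor (fed to the next
- # stage) and the pre-downsample tensor, whose normalized and reshaped
- # form is what SwinTransformer.forward emits for this stage.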
- @MODELS.register_module()
- class SwinTransformer(BaseModule):
- """ Swin Transformer
- A PyTorch implement of : `Swin Transformer:
- Hierarchical Vision Transformer using Shifted Windows` -
- https://arxiv.org/abs/2103.14030
- Inspiration from
- https://github.com/microsoft/Swin-Transformer
- Args:
- pretrain_img_size (int | tuple[int]): The size of the input image used
- in pretraining. Default: 224.
- in_channels (int): The number of input channels.
- Default: 3.
- embed_dims (int): The feature dimension. Default: 96.
- patch_size (int | tuple[int]): Patch size. Default: 4.
- window_size (int): Window size. Default: 7.
- mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
- Default: 4.
- depths (tuple[int]): Depths of each Swin Transformer stage.
- Default: (2, 2, 6, 2).
- num_heads (tuple[int]): Parallel attention heads of each Swin
- Transformer stage. Default: (3, 6, 12, 24).
- strides (tuple[int]): The patch merging or patch embedding stride of
- each Swin Transformer stage. (In swin, we set kernel size equal to
- stride.) Default: (4, 2, 2, 2).
- out_indices (tuple[int]): Output from which stages.
- Default: (0, 1, 2, 3).
- qkv_bias (bool, optional): If True, add a learnable bias to query, key,
- value. Default: True
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- patch_norm (bool): Whether to add a norm layer for patch embed and
- patch merging. Default: True.
- drop_rate (float): Dropout rate. Default: 0.
- attn_drop_rate (float): Attention dropout rate. Default: 0.
- drop_path_rate (float): Stochastic depth rate. Default: 0.1.
- use_abs_pos_embed (bool): If True, add an absolute position embedding
- to the patch embedding. Default: False.
- act_cfg (dict): Config dict for activation layer.
- Default: dict(type='GELU').
- norm_cfg (dict): Config dict for the normalization layer at the
- output of the backbone. Default: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- pretrained (str, optional): Path to a pretrained model. Default: None.
- convert_weights (bool): The flag indicates whether the
- pre-trained model is from the original repo. We may need
- to convert some keys to make it compatible.
- Default: False.
- frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
- Default: -1 (-1 means not freezing any parameters).
- init_cfg (dict, optional): The Config for initialization.
- Defaults to None.
- """
- def __init__(self,
- pretrain_img_size=224,
- in_channels=3,
- embed_dims=96,
- patch_size=4,
- window_size=7,
- mlp_ratio=4,
- depths=(2, 2, 6, 2),
- num_heads=(3, 6, 12, 24),
- strides=(4, 2, 2, 2),
- out_indices=(0, 1, 2, 3),
- qkv_bias=True,
- qk_scale=None,
- patch_norm=True,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- use_abs_pos_embed=False,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False,
- pretrained=None,
- convert_weights=False,
- frozen_stages=-1,
- init_cfg=None):
- self.convert_weights = convert_weights
- self.frozen_stages = frozen_stages
- if isinstance(pretrain_img_size, int):
- pretrain_img_size = to_2tuple(pretrain_img_size)
- elif isinstance(pretrain_img_size, tuple):
- if len(pretrain_img_size) == 1:
- pretrain_img_size = to_2tuple(pretrain_img_size[0])
- assert len(pretrain_img_size) == 2, \
- f'The size of image should have length 1 or 2, ' \
- f'but got {len(pretrain_img_size)}'
- assert not (init_cfg and pretrained), \
- 'init_cfg and pretrained cannot be specified at the same time'
- if isinstance(pretrained, str):
- warnings.warn('DeprecationWarning: pretrained is deprecated, '
- 'please use "init_cfg" instead')
- self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
- elif pretrained is None:
- self.init_cfg = init_cfg
- else:
- raise TypeError('pretrained must be a str or None')
- super(SwinTransformer, self).__init__(init_cfg=init_cfg)
- num_layers = len(depths)
- self.out_indices = out_indices
- self.use_abs_pos_embed = use_abs_pos_embed
- assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
- self.patch_embed = PatchEmbed(
- in_channels=in_channels,
- embed_dims=embed_dims,
- conv_type='Conv2d',
- kernel_size=patch_size,
- stride=strides[0],
- norm_cfg=norm_cfg if patch_norm else None,
- init_cfg=None)
- if self.use_abs_pos_embed:
- patch_row = pretrain_img_size[0] // patch_size
- patch_col = pretrain_img_size[1] // patch_size
- num_patches = patch_row * patch_col
- self.absolute_pos_embed = nn.Parameter(
- torch.zeros((1, num_patches, embed_dims)))
- self.drop_after_pos = nn.Dropout(p=drop_rate)
- # set stochastic depth decay rule
- total_depth = sum(depths)
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
- ]
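- # e.g. depths=(2, 2, 6, 2) with drop_path_rate=0.1 gives 12 values
- # increasing linearly from 0.0 to 0.1 (step 0.1 / 11), so deeper blocks
- # get a higher drop-path probability.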
- self.stages = ModuleList()
- in_channels = embed_dims
- for i in range(num_layers):
- if i < num_layers - 1:
- downsample = PatchMerging(
- in_channels=in_channels,
- out_channels=2 * in_channels,
- stride=strides[i + 1],
- norm_cfg=norm_cfg if patch_norm else None,
- init_cfg=None)
- else:
- downsample = None
- stage = SwinBlockSequence(
- embed_dims=in_channels,
- num_heads=num_heads[i],
- feedforward_channels=mlp_ratio * in_channels,
- depth=depths[i],
- window_size=window_size,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop_rate=drop_rate,
- attn_drop_rate=attn_drop_rate,
- drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
- downsample=downsample,
- act_cfg=act_cfg,
- norm_cfg=norm_cfg,
- with_cp=with_cp,
- init_cfg=None)
- self.stages.append(stage)
- if downsample:
- in_channels = downsample.out_channels
- self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
- # Add a norm layer for each output
- for i in out_indices:
- layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
- layer_name = f'norm{i}'
- self.add_module(layer_name, layer)
- def train(self, mode=True):
- """Convert the model into training mode while keep layers freezed."""
- super(SwinTransformer, self).train(mode)
- self._freeze_stages()
- def _freeze_stages(self):
- if self.frozen_stages >= 0:
- self.patch_embed.eval()
- for param in self.patch_embed.parameters():
- param.requires_grad = False
- if self.use_abs_pos_embed:
- self.absolute_pos_embed.requires_grad = False
- self.drop_after_pos.eval()
- for i in range(1, self.frozen_stages + 1):
- if (i - 1) in self.out_indices:
- norm_layer = getattr(self, f'norm{i-1}')
- norm_layer.eval()
- for param in norm_layer.parameters():
- param.requires_grad = False
- m = self.stages[i - 1]
- m.eval()
- for param in m.parameters():
- param.requires_grad = False
- def init_weights(self):
- logger = MMLogger.get_current_instance()
- if self.init_cfg is None:
- logger.warning(f'No pre-trained weights for '
- f'{self.__class__.__name__}, '
- f'training starts from scratch')
- if self.use_abs_pos_embed:
- trunc_normal_(self.absolute_pos_embed, std=0.02)
- for m in self.modules():
- if isinstance(m, nn.Linear):
- trunc_normal_init(m, std=.02, bias=0.)
- elif isinstance(m, nn.LayerNorm):
- constant_init(m, 1.0)
- else:
- assert 'checkpoint' in self.init_cfg, \
- f'Only support specifying `Pretrained` in `init_cfg` of ' \
- f'{self.__class__.__name__}'
- ckpt = CheckpointLoader.load_checkpoint(
- self.init_cfg.checkpoint, logger=logger, map_location='cpu')
- if 'state_dict' in ckpt:
- _state_dict = ckpt['state_dict']
- elif 'model' in ckpt:
- _state_dict = ckpt['model']
- else:
- _state_dict = ckpt
- if self.convert_weights:
- # support loading weights from the original repo
- _state_dict = swin_converter(_state_dict)
- state_dict = OrderedDict()
- for k, v in _state_dict.items():
- if k.startswith('backbone.'):
- state_dict[k[9:]] = v
- # strip prefix of state_dict
- if list(state_dict.keys())[0].startswith('module.'):
- state_dict = {k[7:]: v for k, v in state_dict.items()}
- # reshape absolute position embedding
- if state_dict.get('absolute_pos_embed') is not None:
- absolute_pos_embed = state_dict['absolute_pos_embed']
- N1, L, C1 = absolute_pos_embed.size()
- N2, C2, H, W = self.absolute_pos_embed.size()
- if N1 != N2 or C1 != C2 or L != H * W:
- logger.warning('Error in loading absolute_pos_embed, pass')
- else:
- state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
- N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
- # interpolate position bias table if needed
- relative_position_bias_table_keys = [
- k for k in state_dict.keys()
- if 'relative_position_bias_table' in k
- ]
- for table_key in relative_position_bias_table_keys:
- table_pretrained = state_dict[table_key]
- table_current = self.state_dict()[table_key]
- L1, nH1 = table_pretrained.size()
- L2, nH2 = table_current.size()
- if nH1 != nH2:
- logger.warning(f'Error in loading {table_key}, pass')
- elif L1 != L2:
- S1 = int(L1**0.5)
- S2 = int(L2**0.5)
- table_pretrained_resized = F.interpolate(
- table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
- size=(S2, S2),
- mode='bicubic')
- state_dict[table_key] = table_pretrained_resized.view(
- nH2, L2).permute(1, 0).contiguous()
- # load state_dict
- self.load_state_dict(state_dict, strict=False)
- def forward(self, x):
- x, hw_shape = self.patch_embed(x)
- if self.use_abs_pos_embed:
- x = x + self.absolute_pos_embed
- x = self.drop_after_pos(x)
- outs = []
- for i, stage in enumerate(self.stages):
- x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
- if i in self.out_indices:
- norm_layer = getattr(self, f'norm{i}')
- out = norm_layer(out)
- out = out.view(-1, *out_hw_shape,
- self.num_features[i]).permute(0, 3, 1,
- 2).contiguous()
- outs.append(out)
- return outs
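- # Usage sketch (illustrative input size): the default Swin-T settings on
- # a 224x224 image yield one feature map per stage.
- #     >>> model = SwinTransformer()
- #     >>> outs = model(torch.rand(1, 3, 224, 224))
- #     >>> [tuple(o.shape) for o in outs]
- #     [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]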
- def swin_converter(ckpt):
- new_ckpt = OrderedDict()
- def correct_unfold_reduction_order(x):
- out_channel, in_channel = x.shape
- x = x.reshape(out_channel, 4, in_channel // 4)
- x = x[:, [0, 2, 1, 3], :].transpose(1,
- 2).reshape(out_channel, in_channel)
- return x
- def correct_unfold_norm_order(x):
- in_channel = x.shape[0]
- x = x.reshape(4, in_channel // 4)
- x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
- return x
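- # The two helpers above reorder PatchMerging weights: the original repo
- # concatenates the 2x2 neighborhood as (top-left, bottom-left, top-right,
- # bottom-right), while the nn.Unfold used by mmcv's PatchMerging emits
- # (top-left, top-right, bottom-left, bottom-right) interleaved per input
- # channel, hence the [0, 2, 1, 3] permutation and the transpose.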
- for k, v in ckpt.items():
- if k.startswith('head'):
- continue
- elif k.startswith('layers'):
- new_v = v
- if 'attn.' in k:
- new_k = k.replace('attn.', 'attn.w_msa.')
- elif 'mlp.' in k:
- if 'mlp.fc1.' in k:
- new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
- elif 'mlp.fc2.' in k:
- new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
- else:
- new_k = k.replace('mlp.', 'ffn.')
- elif 'downsample' in k:
- new_k = k
- if 'reduction.' in k:
- new_v = correct_unfold_reduction_order(v)
- elif 'norm.' in k:
- new_v = correct_unfold_norm_order(v)
- else:
- new_k = k
- new_k = new_k.replace('layers', 'stages', 1)
- elif k.startswith('patch_embed'):
- new_v = v
- if 'proj' in k:
- new_k = k.replace('proj', 'projection')
- else:
- new_k = k
- else:
- new_v = v
- new_k = k
- new_ckpt['backbone.' + new_k] = new_v
- return new_ckpt
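- # Usage sketch (hypothetical checkpoint path; the official checkpoints
- # store their weights under the 'model' key):
- #     >>> ckpt = torch.load('swin_tiny_patch4_window7_224.pth',
- #     ...                   map_location='cpu')
- #     >>> converted = swin_converter(ckpt['model'])
- # In practice this conversion is triggered automatically by passing
- # convert_weights=True to SwinTransformer.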
|