swin.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.logging import MMLogger
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, trunc_normal_,
                                        trunc_normal_init)
from mmengine.runner.checkpoint import CheckpointLoader
from mmengine.utils import to_2tuple

from mmdet.registry import MODELS
from ..layers import PatchEmbed, PatchMerging


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):
        super().__init__()
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5
        self.init_cfg = init_cfg

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
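

# A minimal verification sketch (added for illustration, not part of the
# upstream file). It checks that the `double_step_seq` shortcut above
# reproduces the standard relative-position index of the official Swin
# implementation; the helper name and the 2x2 window size are assumptions.
def _check_relative_position_index():
    Wh, Ww = 2, 2
    # Reference construction: pairwise coordinate differences, shifted to be
    # non-negative and flattened into a single index per token pair.
    ys = torch.arange(Wh).repeat_interleave(Ww)  # row index of each token
    xs = torch.arange(Ww).repeat(Wh)  # column index of each token
    rel_h = ys[:, None] - ys[None, :] + (Wh - 1)
    rel_w = xs[:, None] - xs[None, :] + (Ww - 1)
    reference = rel_h * (2 * Ww - 1) + rel_w
    # Fast construction used by WindowMSA.__init__ above.
    seq = WindowMSA.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
    fast = (seq + seq.T).flip(1)
    assert torch.equal(reference, fast)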


class ShiftWindowMSA(BaseModule):
    """Shifted Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults: 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults: 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults: dict(type='DropPath', drop_prob=0.).
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0,
                 proj_drop_rate=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 init_cfg=None):
        super().__init__(init_cfg)

        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(window_size),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            init_cfg=None)

        self.drop = build_dropout(dropout_layer)

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(
                query,
                shifts=(-self.shift_size, -self.shift_size),
                dims=(1, 2))

            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size,
                                                       None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size,
                                                       None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            shifted_query = query
            attn_mask = None

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x

    def window_reverse(self, windows, H, W):
        """
        Args:
            windows: (num_windows*B, window_size, window_size, C)
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        window_size = self.window_size
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    def window_partition(self, x):
        """
        Args:
            x: (B, H, W, C)
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
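

# A minimal round-trip sketch (added for illustration, not part of the
# upstream file): `window_partition` followed by `window_reverse` should
# reconstruct the padded feature map exactly. The helper name and tensor
# sizes (batch 2, 14x14 map, 96 channels, 7x7 windows) are assumptions.
def _check_window_roundtrip():
    msa = ShiftWindowMSA(embed_dims=96, num_heads=3, window_size=7)
    x = torch.randn(2, 14, 14, 96)  # B, H_pad, W_pad, C (multiples of 7)
    windows = msa.window_partition(x)  # (2 * 2 * 2, 7, 7, 96)
    restored = msa.window_reverse(windows, 14, 14)
    assert torch.equal(x, restored)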


class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        window_size (int, optional): The local window scale. Default: 7.
        shift (bool, optional): whether to shift window or not. Default False.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super(SwinBlock, self).__init__()

        self.init_cfg = init_cfg
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)

            x = x + identity

            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x


class SwinBlockSequence(BaseModule):
    """Implements one stage in Swin Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int, optional): The local window scale. Default: 7.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float | list[float], optional): Stochastic depth
            rate. Default: 0.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        else:
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            block = SwinBlock(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=feedforward_channels,
                window_size=window_size,
                shift=False if i % 2 == 0 else True,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.blocks.append(block)

        self.downsample = downsample

    def forward(self, x, hw_shape):
        for block in self.blocks:
            x = block(x, hw_shape)

        if self.downsample:
            x_down, down_hw_shape = self.downsample(x, hw_shape)
            return x_down, down_hw_shape, x, hw_shape
        else:
            return x, hw_shape, x, hw_shape


@MODELS.register_module()
class SwinTransformer(BaseModule):
    """Swin Transformer.

    A PyTorch implementation of `Swin Transformer: Hierarchical Vision
    Transformer using Shifted Windows` -
        https://arxiv.org/abs/2103.14030

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): The number of input channels.
            Defaults: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query,
            key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): Whether to add a norm layer for patch embed and
            patch merging. Default: True.
        drop_rate (float): Dropout rate. Defaults: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at
            output of backbone. Defaults: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). Default: -1 (-1 means not freezing any parameters).
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained=None,
                 convert_weights=False,
                 frozen_stages=-1,
                 init_cfg=None):
        self.convert_weights = convert_weights
        self.frozen_stages = frozen_stages
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        super(SwinTransformer, self).__init__(init_cfg=init_cfg)

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # set stochastic depth decay rule
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=mlp_ratio * in_channels,
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.stages.append(stage)
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [
            int(embed_dims * 2**i) for i in range(num_layers)
        ]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers
        frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            if self.use_abs_pos_embed:
                self.absolute_pos_embed.requires_grad = False
            self.drop_after_pos.eval()

        for i in range(1, self.frozen_stages + 1):

            if (i - 1) in self.out_indices:
                norm_layer = getattr(self, f'norm{i-1}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.stages[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        logger = MMLogger.get_current_instance()
        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = CheckpointLoader.load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt
            if self.convert_weights:
                # support loading weights from the original repo
                _state_dict = swin_converter(_state_dict)

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(
                            1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(out)
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)

        return outs
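

# A minimal illustration (added for clarity, not part of the upstream file)
# of the key renaming performed by `swin_converter` below on a checkpoint
# from the original microsoft/Swin-Transformer repo; the example keys are
# assumptions:
#
#   layers.0.blocks.0.attn.qkv.weight
#       -> backbone.stages.0.blocks.0.attn.w_msa.qkv.weight
#   layers.0.blocks.0.mlp.fc1.weight
#       -> backbone.stages.0.blocks.0.ffn.layers.0.0.weight
#   patch_embed.proj.weight
#       -> backbone.patch_embed.projection.weight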
def swin_converter(ckpt):

    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(
            out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt['backbone.' + new_k] = new_v

    return new_ckpt
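

# A minimal usage sketch (added for illustration, not part of the upstream
# file): build a Swin-T sized backbone with the defaults above and inspect
# the multi-scale outputs. The 224x224 input is an illustrative assumption.
if __name__ == '__main__':
    model = SwinTransformer()
    model.init_weights()
    feats = model(torch.randn(1, 3, 224, 224))
    for feat in feats:
        print(tuple(feat.shape))
    # With the default config this should print feature maps at strides
    # 4/8/16/32: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14),
    # (1, 768, 7, 7).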