test_swin.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones.swin import SwinBlock, SwinTransformer
  5. class TestSwin(TestCase):
  6. def test_swin_block(self):
  7. # test SwinBlock structure and forward
  8. block = SwinBlock(embed_dims=64, num_heads=4, feedforward_channels=256)
  9. self.assertEqual(block.ffn.embed_dims, 64)
  10. self.assertEqual(block.attn.w_msa.num_heads, 4)
  11. self.assertEqual(block.ffn.feedforward_channels, 256)
  12. x = torch.randn(1, 56 * 56, 64)
  13. x_out = block(x, (56, 56))
  14. self.assertEqual(x_out.shape, torch.Size([1, 56 * 56, 64]))
  15. # Test BasicBlock with checkpoint forward
  16. block = SwinBlock(
  17. embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
  18. self.assertTrue(block.with_cp)
  19. x = torch.randn(1, 56 * 56, 64)
  20. x_out = block(x, (56, 56))
  21. self.assertEqual(x_out.shape, torch.Size([1, 56 * 56, 64]))
  22. def test_swin_transformer(self):
  23. """Test Swin Transformer backbone."""
  24. with self.assertRaises(AssertionError):
  25. # Because swin uses non-overlapping patch embed, so the stride of
  26. # patch embed must be equal to patch size.
  27. SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
  28. # test pretrained image size
  29. with self.assertRaises(AssertionError):
  30. SwinTransformer(pretrain_img_size=(224, 224, 224))
  31. # Test absolute position embedding
  32. temp = torch.randn((1, 3, 224, 224))
  33. model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)
  34. model.init_weights()
  35. model(temp)
  36. # Test patch norm
  37. model = SwinTransformer(patch_norm=False)
  38. model(temp)
  39. # Test normal inference
  40. temp = torch.randn((1, 3, 32, 32))
  41. model = SwinTransformer()
  42. outs = model(temp)
  43. self.assertEqual(outs[0].shape, (1, 96, 8, 8))
  44. self.assertEqual(outs[1].shape, (1, 192, 4, 4))
  45. self.assertEqual(outs[2].shape, (1, 384, 2, 2))
  46. self.assertEqual(outs[3].shape, (1, 768, 1, 1))
  47. # Test abnormal inference size
  48. temp = torch.randn((1, 3, 31, 31))
  49. model = SwinTransformer()
  50. outs = model(temp)
  51. self.assertEqual(outs[0].shape, (1, 96, 8, 8))
  52. self.assertEqual(outs[1].shape, (1, 192, 4, 4))
  53. self.assertEqual(outs[2].shape, (1, 384, 2, 2))
  54. self.assertEqual(outs[3].shape, (1, 768, 1, 1))
  55. # Test abnormal inference size
  56. temp = torch.randn((1, 3, 112, 137))
  57. model = SwinTransformer()
  58. outs = model(temp)
  59. self.assertEqual(outs[0].shape, (1, 96, 28, 35))
  60. self.assertEqual(outs[1].shape, (1, 192, 14, 18))
  61. self.assertEqual(outs[2].shape, (1, 384, 7, 9))
  62. self.assertEqual(outs[3].shape, (1, 768, 4, 5))
  63. model = SwinTransformer(frozen_stages=4)
  64. model.train()
  65. for p in model.parameters():
  66. self.assertFalse(p.requires_grad)