# test_pvt.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones.pvt import (PVTEncoderLayer,
  5. PyramidVisionTransformer,
  6. PyramidVisionTransformerV2)
  7. class TestPVT(TestCase):
  8. def test_pvt_block(self):
  9. # test PVT structure and forward
  10. block = PVTEncoderLayer(
  11. embed_dims=64, num_heads=4, feedforward_channels=256)
  12. self.assertEqual(block.ffn.embed_dims, 64)
  13. self.assertEqual(block.attn.num_heads, 4)
  14. self.assertEqual(block.ffn.feedforward_channels, 256)
  15. x = torch.randn(1, 56 * 56, 64)
  16. x_out = block(x, (56, 56))
  17. self.assertEqual(x_out.shape, torch.Size([1, 56 * 56, 64]))
  18. def test_pvt(self):
  19. """Test PVT backbone."""
  20. # test pretrained image size
  21. with self.assertRaises(AssertionError):
  22. PyramidVisionTransformer(pretrain_img_size=(224, 224, 224))
  23. # test padding
  24. model = PyramidVisionTransformer(
  25. paddings=['corner', 'corner', 'corner', 'corner'])
  26. temp = torch.randn((1, 3, 32, 32))
  27. outs = model(temp)
  28. self.assertEqual(outs[0].shape, (1, 64, 8, 8))
  29. self.assertEqual(outs[1].shape, (1, 128, 4, 4))
  30. self.assertEqual(outs[2].shape, (1, 320, 2, 2))
  31. self.assertEqual(outs[3].shape, (1, 512, 1, 1))
  32. # Test absolute position embedding
  33. temp = torch.randn((1, 3, 224, 224))
  34. model = PyramidVisionTransformer(
  35. pretrain_img_size=224, use_abs_pos_embed=True)
  36. model.init_weights()
  37. model(temp)
  38. # Test normal inference
  39. temp = torch.randn((1, 3, 32, 32))
  40. model = PyramidVisionTransformer()
  41. outs = model(temp)
  42. self.assertEqual(outs[0].shape, (1, 64, 8, 8))
  43. self.assertEqual(outs[1].shape, (1, 128, 4, 4))
  44. self.assertEqual(outs[2].shape, (1, 320, 2, 2))
  45. self.assertEqual(outs[3].shape, (1, 512, 1, 1))
  46. # Test abnormal inference size
  47. temp = torch.randn((1, 3, 33, 33))
  48. model = PyramidVisionTransformer()
  49. outs = model(temp)
  50. self.assertEqual(outs[0].shape, (1, 64, 8, 8))
  51. self.assertEqual(outs[1].shape, (1, 128, 4, 4))
  52. self.assertEqual(outs[2].shape, (1, 320, 2, 2))
  53. self.assertEqual(outs[3].shape, (1, 512, 1, 1))
  54. # Test abnormal inference size
  55. temp = torch.randn((1, 3, 112, 137))
  56. model = PyramidVisionTransformer()
  57. outs = model(temp)
  58. self.assertEqual(outs[0].shape, (1, 64, 28, 34))
  59. self.assertEqual(outs[1].shape, (1, 128, 14, 17))
  60. self.assertEqual(outs[2].shape, (1, 320, 7, 8))
  61. self.assertEqual(outs[3].shape, (1, 512, 3, 4))
  62. def test_pvtv2(self):
  63. """Test PVTv2 backbone."""
  64. with self.assertRaises(TypeError):
  65. # Pretrained arg must be str or None.
  66. PyramidVisionTransformerV2(pretrained=123)
  67. # test pretrained image size
  68. with self.assertRaises(AssertionError):
  69. PyramidVisionTransformerV2(pretrain_img_size=(224, 224, 224))
  70. # test init_cfg with pretrained model
  71. model = PyramidVisionTransformerV2(
  72. embed_dims=32,
  73. num_layers=[2, 2, 2, 2],
  74. init_cfg=dict(
  75. type='Pretrained',
  76. checkpoint='https://github.com/whai362/PVT/'
  77. 'releases/download/v2/pvt_v2_b0.pth'))
  78. model.init_weights()
  79. # test init weights from scratch
  80. model = PyramidVisionTransformerV2(
  81. embed_dims=32, num_layers=[2, 2, 2, 2])
  82. model.init_weights()
  83. # Test normal inference
  84. temp = torch.randn((1, 3, 32, 32))
  85. model = PyramidVisionTransformerV2()
  86. outs = model(temp)
  87. self.assertEqual(outs[0].shape, (1, 64, 8, 8))
  88. self.assertEqual(outs[1].shape, (1, 128, 4, 4))
  89. self.assertEqual(outs[2].shape, (1, 320, 2, 2))
  90. self.assertEqual(outs[3].shape, (1, 512, 1, 1))
  91. # Test abnormal inference size
  92. temp = torch.randn((1, 3, 31, 31))
  93. model = PyramidVisionTransformerV2()
  94. outs = model(temp)
  95. self.assertEqual(outs[0].shape, (1, 64, 8, 8))
  96. self.assertEqual(outs[1].shape, (1, 128, 4, 4))
  97. self.assertEqual(outs[2].shape, (1, 320, 2, 2))
  98. self.assertEqual(outs[3].shape, (1, 512, 1, 1))
  99. # Test abnormal inference size
  100. temp = torch.randn((1, 3, 112, 137))
  101. model = PyramidVisionTransformerV2()
  102. outs = model(temp)
  103. self.assertEqual(outs[0].shape, (1, 64, 28, 35))
  104. self.assertEqual(outs[1].shape, (1, 128, 14, 18))
  105. self.assertEqual(outs[2].shape, (1, 320, 7, 9))
  106. self.assertEqual(outs[3].shape, (1, 512, 4, 5))