test_pvt.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import pytest
  2. import torch
  3. from mmdet.models.backbones.pvt import (PVTEncoderLayer,
  4. PyramidVisionTransformer,
  5. PyramidVisionTransformerV2)
  6. def test_pvt_block():
  7. # test PVT structure and forward
  8. block = PVTEncoderLayer(
  9. embed_dims=64, num_heads=4, feedforward_channels=256)
  10. assert block.ffn.embed_dims == 64
  11. assert block.attn.num_heads == 4
  12. assert block.ffn.feedforward_channels == 256
  13. x = torch.randn(1, 56 * 56, 64)
  14. x_out = block(x, (56, 56))
  15. assert x_out.shape == torch.Size([1, 56 * 56, 64])
  16. def test_pvt():
  17. """Test PVT backbone."""
  18. with pytest.raises(TypeError):
  19. # Pretrained arg must be str or None.
  20. PyramidVisionTransformer(pretrained=123)
  21. # test pretrained image size
  22. with pytest.raises(AssertionError):
  23. PyramidVisionTransformer(pretrain_img_size=(224, 224, 224))
  24. # Test absolute position embedding
  25. temp = torch.randn((1, 3, 224, 224))
  26. model = PyramidVisionTransformer(
  27. pretrain_img_size=224, use_abs_pos_embed=True)
  28. model.init_weights()
  29. model(temp)
  30. # Test normal inference
  31. temp = torch.randn((1, 3, 32, 32))
  32. model = PyramidVisionTransformer()
  33. outs = model(temp)
  34. assert outs[0].shape == (1, 64, 8, 8)
  35. assert outs[1].shape == (1, 128, 4, 4)
  36. assert outs[2].shape == (1, 320, 2, 2)
  37. assert outs[3].shape == (1, 512, 1, 1)
  38. # Test abnormal inference size
  39. temp = torch.randn((1, 3, 33, 33))
  40. model = PyramidVisionTransformer()
  41. outs = model(temp)
  42. assert outs[0].shape == (1, 64, 8, 8)
  43. assert outs[1].shape == (1, 128, 4, 4)
  44. assert outs[2].shape == (1, 320, 2, 2)
  45. assert outs[3].shape == (1, 512, 1, 1)
  46. # Test abnormal inference size
  47. temp = torch.randn((1, 3, 112, 137))
  48. model = PyramidVisionTransformer()
  49. outs = model(temp)
  50. assert outs[0].shape == (1, 64, 28, 34)
  51. assert outs[1].shape == (1, 128, 14, 17)
  52. assert outs[2].shape == (1, 320, 7, 8)
  53. assert outs[3].shape == (1, 512, 3, 4)
  54. def test_pvtv2():
  55. """Test PVTv2 backbone."""
  56. with pytest.raises(TypeError):
  57. # Pretrained arg must be str or None.
  58. PyramidVisionTransformerV2(pretrained=123)
  59. # test pretrained image size
  60. with pytest.raises(AssertionError):
  61. PyramidVisionTransformerV2(pretrain_img_size=(224, 224, 224))
  62. # Test normal inference
  63. temp = torch.randn((1, 3, 32, 32))
  64. model = PyramidVisionTransformerV2()
  65. outs = model(temp)
  66. assert outs[0].shape == (1, 64, 8, 8)
  67. assert outs[1].shape == (1, 128, 4, 4)
  68. assert outs[2].shape == (1, 320, 2, 2)
  69. assert outs[3].shape == (1, 512, 1, 1)
  70. # Test abnormal inference size
  71. temp = torch.randn((1, 3, 31, 31))
  72. model = PyramidVisionTransformerV2()
  73. outs = model(temp)
  74. assert outs[0].shape == (1, 64, 8, 8)
  75. assert outs[1].shape == (1, 128, 4, 4)
  76. assert outs[2].shape == (1, 320, 2, 2)
  77. assert outs[3].shape == (1, 512, 1, 1)
  78. # Test abnormal inference size
  79. temp = torch.randn((1, 3, 112, 137))
  80. model = PyramidVisionTransformerV2()
  81. outs = model(temp)
  82. assert outs[0].shape == (1, 64, 28, 35)
  83. assert outs[1].shape == (1, 128, 14, 18)
  84. assert outs[2].shape == (1, 320, 7, 9)
  85. assert outs[3].shape == (1, 512, 4, 5)