test_position_encoding.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.layers import (LearnedPositionalEncoding,
  5. SinePositionalEncoding)
  6. def test_sine_positional_encoding(num_feats=16, batch_size=2):
  7. # test invalid type of scale
  8. with pytest.raises(AssertionError):
  9. module = SinePositionalEncoding(
  10. num_feats, scale=(3., ), normalize=True)
  11. module = SinePositionalEncoding(num_feats)
  12. h, w = 10, 6
  13. mask = (torch.rand(batch_size, h, w) > 0.5).to(torch.int)
  14. assert not module.normalize
  15. out = module(mask)
  16. assert out.shape == (batch_size, num_feats * 2, h, w)
  17. # set normalize
  18. module = SinePositionalEncoding(num_feats, normalize=True)
  19. assert module.normalize
  20. out = module(mask)
  21. assert out.shape == (batch_size, num_feats * 2, h, w)
  22. def test_learned_positional_encoding(num_feats=16,
  23. row_num_embed=10,
  24. col_num_embed=10,
  25. batch_size=2):
  26. module = LearnedPositionalEncoding(num_feats, row_num_embed, col_num_embed)
  27. assert module.row_embed.weight.shape == (row_num_embed, num_feats)
  28. assert module.col_embed.weight.shape == (col_num_embed, num_feats)
  29. h, w = 10, 6
  30. mask = torch.rand(batch_size, h, w) > 0.5
  31. out = module(mask)
  32. assert out.shape == (batch_size, num_feats * 2, h, w)