test_brick_wrappers.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from unittest.mock import patch
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmdet.models.layers import AdaptiveAvgPool2d, adaptive_avg_pool2d
  6. if torch.__version__ != 'parrots':
  7. torch_version = '1.7'
  8. else:
  9. torch_version = 'parrots'
  10. @patch('torch.__version__', torch_version)
  11. def test_adaptive_avg_pool2d():
  12. # Test the empty batch dimension
  13. # Test the two input conditions
  14. x_empty = torch.randn(0, 3, 4, 5)
  15. # 1. tuple[int, int]
  16. wrapper_out = adaptive_avg_pool2d(x_empty, (2, 2))
  17. assert wrapper_out.shape == (0, 3, 2, 2)
  18. # 2. int
  19. wrapper_out = adaptive_avg_pool2d(x_empty, 2)
  20. assert wrapper_out.shape == (0, 3, 2, 2)
  21. # wrapper op with 3-dim input
  22. x_normal = torch.randn(3, 3, 4, 5)
  23. wrapper_out = adaptive_avg_pool2d(x_normal, (2, 2))
  24. ref_out = F.adaptive_avg_pool2d(x_normal, (2, 2))
  25. assert wrapper_out.shape == (3, 3, 2, 2)
  26. assert torch.equal(wrapper_out, ref_out)
  27. wrapper_out = adaptive_avg_pool2d(x_normal, 2)
  28. ref_out = F.adaptive_avg_pool2d(x_normal, 2)
  29. assert wrapper_out.shape == (3, 3, 2, 2)
  30. assert torch.equal(wrapper_out, ref_out)
  31. @patch('torch.__version__', torch_version)
  32. def test_AdaptiveAvgPool2d():
  33. # Test the empty batch dimension
  34. x_empty = torch.randn(0, 3, 4, 5)
  35. # Test the four input conditions
  36. # 1. tuple[int, int]
  37. wrapper = AdaptiveAvgPool2d((2, 2))
  38. wrapper_out = wrapper(x_empty)
  39. assert wrapper_out.shape == (0, 3, 2, 2)
  40. # 2. int
  41. wrapper = AdaptiveAvgPool2d(2)
  42. wrapper_out = wrapper(x_empty)
  43. assert wrapper_out.shape == (0, 3, 2, 2)
  44. # 3. tuple[None, int]
  45. wrapper = AdaptiveAvgPool2d((None, 2))
  46. wrapper_out = wrapper(x_empty)
  47. assert wrapper_out.shape == (0, 3, 4, 2)
  48. # 3. tuple[int, None]
  49. wrapper = AdaptiveAvgPool2d((2, None))
  50. wrapper_out = wrapper(x_empty)
  51. assert wrapper_out.shape == (0, 3, 2, 5)
  52. # Test the normal batch dimension
  53. x_normal = torch.randn(3, 3, 4, 5)
  54. wrapper = AdaptiveAvgPool2d((2, 2))
  55. ref = nn.AdaptiveAvgPool2d((2, 2))
  56. wrapper_out = wrapper(x_normal)
  57. ref_out = ref(x_normal)
  58. assert wrapper_out.shape == (3, 3, 2, 2)
  59. assert torch.equal(wrapper_out, ref_out)
  60. wrapper = AdaptiveAvgPool2d(2)
  61. ref = nn.AdaptiveAvgPool2d(2)
  62. wrapper_out = wrapper(x_normal)
  63. ref_out = ref(x_normal)
  64. assert wrapper_out.shape == (3, 3, 2, 2)
  65. assert torch.equal(wrapper_out, ref_out)
  66. wrapper = AdaptiveAvgPool2d((None, 2))
  67. ref = nn.AdaptiveAvgPool2d((None, 2))
  68. wrapper_out = wrapper(x_normal)
  69. ref_out = ref(x_normal)
  70. assert wrapper_out.shape == (3, 3, 4, 2)
  71. assert torch.equal(wrapper_out, ref_out)
  72. wrapper = AdaptiveAvgPool2d((2, None))
  73. ref = nn.AdaptiveAvgPool2d((2, None))
  74. wrapper_out = wrapper(x_normal)
  75. ref_out = ref(x_normal)
  76. assert wrapper_out.shape == (3, 3, 2, 5)
  77. assert torch.equal(wrapper_out, ref_out)