# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.backbones.utils import (InvertedResidual, SELayer,
                                           channel_shuffle, make_divisible)


class TestBackboneUtils(TestCase):

    @staticmethod
    def is_norm(modules):
        """Check whether the module is one of the supported norm layers
        (GroupNorm or a BatchNorm variant)."""
        return isinstance(modules, (GroupNorm, _BatchNorm))

    def test_make_divisible(self):
        # test min_value is None
        result = make_divisible(34, 8, None)
        self.assertEqual(result, 32)

        # test when the rounded value falls below min_ratio * value,
        # so it is bumped up to the next multiple of the divisor
        result = make_divisible(10, 8, min_ratio=0.9)
        self.assertEqual(result, 16)

        # test min_ratio = 0.8
        result = make_divisible(33, 8, min_ratio=0.8)
        self.assertEqual(result, 32)
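
        # For reference, a minimal sketch of the rounding behaviour the
        # cases above assume (a hypothetical re-implementation for
        # illustration, not the actual mmpose source):
        #
        #   def _make_divisible(value, divisor, min_value=None, min_ratio=0.9):
        #       if min_value is None:
        #           min_value = divisor
        #       # round to the nearest multiple of `divisor`
        #       new_value = max(min_value,
        #                       int(value + divisor / 2) // divisor * divisor)
        #       # ensure rounding down loses at most (1 - min_ratio) of value
        #       if new_value < min_ratio * value:
        #           new_value += divisor
        #       return new_value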

    def test_channel_shuffle(self):
        x = torch.randn(1, 24, 56, 56)
        with self.assertRaisesRegex(
                AssertionError, 'num_channels should be divisible by groups'):
            channel_shuffle(x, 7)

        groups = 3
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // groups
        out = channel_shuffle(x, groups)
        # test the output value when groups = 3
        for b in range(batch_size):
            for c in range(num_channels):
                c_out = c % channels_per_group * groups + \
                    c // channels_per_group
                for i in range(height):
                    for j in range(width):
                        self.assertEqual(x[b, c, i, j], out[b, c_out, i, j])
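
        # The c_out index map above mirrors the conventional
        # reshape-transpose implementation of channel shuffle (a sketch of
        # the standard technique, not necessarily the exact mmpose source):
        #
        #   b, c, h, w = x.size()
        #   x = x.view(b, groups, c // groups, h, w)
        #   x = x.transpose(1, 2).contiguous()
        #   out = x.view(b, c, h, w)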

    def test_inverted_residual(self):
        with self.assertRaises(AssertionError):
            # stride must be in [1, 2]
            InvertedResidual(16, 16, 32, stride=3)

        with self.assertRaises(AssertionError):
            # se_cfg must be None or dict
            InvertedResidual(16, 16, 32, se_cfg=list())

        with self.assertRaises(AssertionError):
            # in_channels and mid_channels must be the same if
            # with_expand_conv is False
            InvertedResidual(16, 16, 32, with_expand_conv=False)

        # Test InvertedResidual forward, stride=1
        block = InvertedResidual(16, 16, 32, stride=1)
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        self.assertIsNone(getattr(block, 'se', None))
        self.assertTrue(block.with_res_shortcut)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))
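
        # with_res_shortcut holds when stride == 1 and the input and output
        # channel counts match, so the block adds a residual connection and
        # preserves the spatial resolution.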

        # Test InvertedResidual forward, stride=2
        block = InvertedResidual(16, 16, 32, stride=2)
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        self.assertFalse(block.with_res_shortcut)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 28, 28)))

        # Test InvertedResidual forward with se layer
        se_cfg = dict(channels=32)
        block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        self.assertIsInstance(block.se, SELayer)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))
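
        # se_cfg is forwarded to SELayer as keyword arguments. Beyond
        # `channels`, SELayer typically also accepts a squeeze `ratio` and
        # activation configs; a fuller config might look like the sketch
        # below (field names assumed from the mmcv-style SELayer, defaults
        # may differ):
        #
        #   se_cfg = dict(
        #       channels=32,
        #       ratio=4,
        #       act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')))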

        # Test InvertedResidual forward, with_expand_conv=False
        block = InvertedResidual(32, 16, 32, with_expand_conv=False)
        x = torch.randn(1, 32, 56, 56)
        x_out = block(x)
        self.assertIsNone(getattr(block, 'expand_conv', None))
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))

        # Test InvertedResidual forward with GroupNorm
        block = InvertedResidual(
            16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        for m in block.modules():
            if self.is_norm(m):
                self.assertIsInstance(m, GroupNorm)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))

        # Test InvertedResidual forward with HSigmoid
        block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))

        # Test InvertedResidual forward with checkpoint
        block = InvertedResidual(16, 16, 32, with_cp=True)
        x = torch.randn(1, 16, 56, 56)
        x_out = block(x)
        self.assertTrue(block.with_cp)
        self.assertEqual(x_out.shape, torch.Size((1, 16, 56, 56)))
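
        # with_cp=True routes the forward pass through
        # torch.utils.checkpoint to trade compute for memory during
        # training; the output should match the non-checkpointed path.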