# test_inverted_residual.py (2.6 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmcv.cnn import is_norm
  5. from torch.nn.modules import GroupNorm
  6. from mmdet.models.layers import InvertedResidual, SELayer
  7. def test_inverted_residual():
  8. with pytest.raises(AssertionError):
  9. # stride must be in [1, 2]
  10. InvertedResidual(16, 16, 32, stride=3)
  11. with pytest.raises(AssertionError):
  12. # se_cfg must be None or dict
  13. InvertedResidual(16, 16, 32, se_cfg=list())
  14. with pytest.raises(AssertionError):
  15. # in_channeld and mid_channels must be the same if
  16. # with_expand_conv is False
  17. InvertedResidual(16, 16, 32, with_expand_conv=False)
  18. # Test InvertedResidual forward, stride=1
  19. block = InvertedResidual(16, 16, 32, stride=1)
  20. x = torch.randn(1, 16, 56, 56)
  21. x_out = block(x)
  22. assert getattr(block, 'se', None) is None
  23. assert block.with_res_shortcut
  24. assert x_out.shape == torch.Size((1, 16, 56, 56))
  25. # Test InvertedResidual forward, stride=2
  26. block = InvertedResidual(16, 16, 32, stride=2)
  27. x = torch.randn(1, 16, 56, 56)
  28. x_out = block(x)
  29. assert not block.with_res_shortcut
  30. assert x_out.shape == torch.Size((1, 16, 28, 28))
  31. # Test InvertedResidual forward with se layer
  32. se_cfg = dict(channels=32)
  33. block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
  34. x = torch.randn(1, 16, 56, 56)
  35. x_out = block(x)
  36. assert isinstance(block.se, SELayer)
  37. assert x_out.shape == torch.Size((1, 16, 56, 56))
  38. # Test InvertedResidual forward, with_expand_conv=False
  39. block = InvertedResidual(32, 16, 32, with_expand_conv=False)
  40. x = torch.randn(1, 32, 56, 56)
  41. x_out = block(x)
  42. assert getattr(block, 'expand_conv', None) is None
  43. assert x_out.shape == torch.Size((1, 16, 56, 56))
  44. # Test InvertedResidual forward with GroupNorm
  45. block = InvertedResidual(
  46. 16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
  47. x = torch.randn(1, 16, 56, 56)
  48. x_out = block(x)
  49. for m in block.modules():
  50. if is_norm(m):
  51. assert isinstance(m, GroupNorm)
  52. assert x_out.shape == torch.Size((1, 16, 56, 56))
  53. # Test InvertedResidual forward with HSigmoid
  54. block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
  55. x = torch.randn(1, 16, 56, 56)
  56. x_out = block(x)
  57. assert x_out.shape == torch.Size((1, 16, 56, 56))
  58. # Test InvertedResidual forward with checkpoint
  59. block = InvertedResidual(16, 16, 32, with_cp=True)
  60. x = torch.randn(1, 16, 56, 56)
  61. x_out = block(x)
  62. assert block.with_cp
  63. assert x_out.shape == torch.Size((1, 16, 56, 56))