test_regnet.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.backbones import RegNet
  5. regnet_test_data = [
  6. ('regnetx_400mf',
  7. dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
  8. bot_mul=1.0), [32, 64, 160, 384]),
  9. ('regnetx_800mf',
  10. dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
  11. bot_mul=1.0), [64, 128, 288, 672]),
  12. ('regnetx_1.6gf',
  13. dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
  14. bot_mul=1.0), [72, 168, 408, 912]),
  15. ('regnetx_3.2gf',
  16. dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
  17. bot_mul=1.0), [96, 192, 432, 1008]),
  18. ('regnetx_4.0gf',
  19. dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
  20. bot_mul=1.0), [80, 240, 560, 1360]),
  21. ('regnetx_6.4gf',
  22. dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
  23. bot_mul=1.0), [168, 392, 784, 1624]),
  24. ('regnetx_8.0gf',
  25. dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
  26. bot_mul=1.0), [80, 240, 720, 1920]),
  27. ('regnetx_12gf',
  28. dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
  29. bot_mul=1.0), [224, 448, 896, 2240]),
  30. ]
  31. @pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
  32. def test_regnet_backbone(arch_name, arch, out_channels):
  33. with pytest.raises(AssertionError):
  34. # ResNeXt depth should be in [50, 101, 152]
  35. RegNet(arch_name + '233')
  36. # Test RegNet with arch_name
  37. model = RegNet(arch_name)
  38. model.train()
  39. imgs = torch.randn(1, 3, 32, 32)
  40. feat = model(imgs)
  41. assert len(feat) == 4
  42. assert feat[0].shape == torch.Size([1, out_channels[0], 8, 8])
  43. assert feat[1].shape == torch.Size([1, out_channels[1], 4, 4])
  44. assert feat[2].shape == torch.Size([1, out_channels[2], 2, 2])
  45. assert feat[3].shape == torch.Size([1, out_channels[3], 1, 1])
  46. # Test RegNet with arch
  47. model = RegNet(arch)
  48. assert feat[0].shape == torch.Size([1, out_channels[0], 8, 8])
  49. assert feat[1].shape == torch.Size([1, out_channels[1], 4, 4])
  50. assert feat[2].shape == torch.Size([1, out_channels[2], 2, 2])
  51. assert feat[3].shape == torch.Size([1, out_channels[3], 1, 1])