test_renext.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.backbones import ResNeXt
  5. from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
  6. from .utils import is_block
  7. def test_renext_bottleneck():
  8. with pytest.raises(AssertionError):
  9. # Style must be in ['pytorch', 'caffe']
  10. BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
  11. # Test ResNeXt Bottleneck structure
  12. block = BottleneckX(
  13. 64, 64, groups=32, base_width=4, stride=2, style='pytorch')
  14. assert block.conv2.stride == (2, 2)
  15. assert block.conv2.groups == 32
  16. assert block.conv2.out_channels == 128
  17. # Test ResNeXt Bottleneck with DCN
  18. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  19. with pytest.raises(AssertionError):
  20. # conv_cfg must be None if dcn is not None
  21. BottleneckX(
  22. 64,
  23. 64,
  24. groups=32,
  25. base_width=4,
  26. dcn=dcn,
  27. conv_cfg=dict(type='Conv'))
  28. BottleneckX(64, 64, dcn=dcn)
  29. # Test ResNeXt Bottleneck forward
  30. block = BottleneckX(64, 16, groups=32, base_width=4)
  31. x = torch.randn(1, 64, 56, 56)
  32. x_out = block(x)
  33. assert x_out.shape == torch.Size([1, 64, 56, 56])
  34. # Test ResNeXt Bottleneck forward with plugins
  35. plugins = [
  36. dict(
  37. cfg=dict(
  38. type='GeneralizedAttention',
  39. spatial_range=-1,
  40. num_heads=8,
  41. attention_type='0010',
  42. kv_stride=2),
  43. stages=(False, False, True, True),
  44. position='after_conv2')
  45. ]
  46. block = BottleneckX(64, 16, groups=32, base_width=4, plugins=plugins)
  47. x = torch.randn(1, 64, 56, 56)
  48. x_out = block(x)
  49. assert x_out.shape == torch.Size([1, 64, 56, 56])
  50. def test_resnext_backbone():
  51. with pytest.raises(KeyError):
  52. # ResNeXt depth should be in [50, 101, 152]
  53. ResNeXt(depth=18)
  54. # Test ResNeXt with group 32, base_width 4
  55. model = ResNeXt(depth=50, groups=32, base_width=4)
  56. for m in model.modules():
  57. if is_block(m):
  58. assert m.conv2.groups == 32
  59. model.train()
  60. imgs = torch.randn(1, 3, 32, 32)
  61. feat = model(imgs)
  62. assert len(feat) == 4
  63. assert feat[0].shape == torch.Size([1, 256, 8, 8])
  64. assert feat[1].shape == torch.Size([1, 512, 4, 4])
  65. assert feat[2].shape == torch.Size([1, 1024, 2, 2])
  66. assert feat[3].shape == torch.Size([1, 2048, 1, 1])
  67. regnet_test_data = [
  68. ('regnetx_400mf',
  69. dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
  70. bot_mul=1.0), [32, 64, 160, 384]),
  71. ('regnetx_800mf',
  72. dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
  73. bot_mul=1.0), [64, 128, 288, 672]),
  74. ('regnetx_1.6gf',
  75. dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
  76. bot_mul=1.0), [72, 168, 408, 912]),
  77. ('regnetx_3.2gf',
  78. dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
  79. bot_mul=1.0), [96, 192, 432, 1008]),
  80. ('regnetx_4.0gf',
  81. dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
  82. bot_mul=1.0), [80, 240, 560, 1360]),
  83. ('regnetx_6.4gf',
  84. dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
  85. bot_mul=1.0), [168, 392, 784, 1624]),
  86. ('regnetx_8.0gf',
  87. dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
  88. bot_mul=1.0), [80, 240, 720, 1920]),
  89. ('regnetx_12gf',
  90. dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
  91. bot_mul=1.0), [224, 448, 896, 2240]),
  92. ]