# test_res2net.py (1.9 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.backbones import Res2Net
  5. from mmdet.models.backbones.res2net import Bottle2neck
  6. from .utils import is_block
  7. def test_res2net_bottle2neck():
  8. with pytest.raises(AssertionError):
  9. # Style must be in ['pytorch', 'caffe']
  10. Bottle2neck(64, 64, base_width=26, scales=4, style='tensorflow')
  11. with pytest.raises(AssertionError):
  12. # Scale must be larger than 1
  13. Bottle2neck(64, 64, base_width=26, scales=1, style='pytorch')
  14. # Test Res2Net Bottle2neck structure
  15. block = Bottle2neck(
  16. 64, 64, base_width=26, stride=2, scales=4, style='pytorch')
  17. assert block.scales == 4
  18. # Test Res2Net Bottle2neck with DCN
  19. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  20. with pytest.raises(AssertionError):
  21. # conv_cfg must be None if dcn is not None
  22. Bottle2neck(
  23. 64,
  24. 64,
  25. base_width=26,
  26. scales=4,
  27. dcn=dcn,
  28. conv_cfg=dict(type='Conv'))
  29. Bottle2neck(64, 64, dcn=dcn)
  30. # Test Res2Net Bottle2neck forward
  31. block = Bottle2neck(64, 16, base_width=26, scales=4)
  32. x = torch.randn(1, 64, 56, 56)
  33. x_out = block(x)
  34. assert x_out.shape == torch.Size([1, 64, 56, 56])
  35. def test_res2net_backbone():
  36. with pytest.raises(KeyError):
  37. # Res2Net depth should be in [50, 101, 152]
  38. Res2Net(depth=18)
  39. # Test Res2Net with scales 4, base_width 26
  40. model = Res2Net(depth=50, scales=4, base_width=26)
  41. for m in model.modules():
  42. if is_block(m):
  43. assert m.scales == 4
  44. model.train()
  45. imgs = torch.randn(1, 3, 32, 32)
  46. feat = model(imgs)
  47. assert len(feat) == 4
  48. assert feat[0].shape == torch.Size([1, 256, 8, 8])
  49. assert feat[1].shape == torch.Size([1, 512, 4, 4])
  50. assert feat[2].shape == torch.Size([1, 1024, 2, 2])
  51. assert feat[3].shape == torch.Size([1, 2048, 1, 1])