test_csp_darknet.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from torch.nn.modules.batchnorm import _BatchNorm
  5. from mmdet.models.backbones.csp_darknet import CSPDarknet
  6. from .utils import check_norm_state, is_norm
  7. def test_csp_darknet_backbone():
  8. with pytest.raises(ValueError):
  9. # frozen_stages must in range(-1, len(arch_setting) + 1)
  10. CSPDarknet(frozen_stages=6)
  11. with pytest.raises(AssertionError):
  12. # out_indices in range(len(arch_setting) + 1)
  13. CSPDarknet(out_indices=[6])
  14. # Test CSPDarknet with first stage frozen
  15. frozen_stages = 1
  16. model = CSPDarknet(frozen_stages=frozen_stages)
  17. model.train()
  18. for mod in model.stem.modules():
  19. for param in mod.parameters():
  20. assert param.requires_grad is False
  21. for i in range(1, frozen_stages + 1):
  22. layer = getattr(model, f'stage{i}')
  23. for mod in layer.modules():
  24. if isinstance(mod, _BatchNorm):
  25. assert mod.training is False
  26. for param in layer.parameters():
  27. assert param.requires_grad is False
  28. # Test CSPDarknet with norm_eval=True
  29. model = CSPDarknet(norm_eval=True)
  30. model.train()
  31. assert check_norm_state(model.modules(), False)
  32. # Test CSPDarknet-P5 forward with widen_factor=0.5
  33. model = CSPDarknet(arch='P5', widen_factor=0.25, out_indices=range(0, 5))
  34. model.train()
  35. imgs = torch.randn(1, 3, 64, 64)
  36. feat = model(imgs)
  37. assert len(feat) == 5
  38. assert feat[0].shape == torch.Size((1, 16, 32, 32))
  39. assert feat[1].shape == torch.Size((1, 32, 16, 16))
  40. assert feat[2].shape == torch.Size((1, 64, 8, 8))
  41. assert feat[3].shape == torch.Size((1, 128, 4, 4))
  42. assert feat[4].shape == torch.Size((1, 256, 2, 2))
  43. # Test CSPDarknet-P6 forward with widen_factor=0.5
  44. model = CSPDarknet(
  45. arch='P6',
  46. widen_factor=0.25,
  47. out_indices=range(0, 6),
  48. spp_kernal_sizes=(3, 5, 7))
  49. model.train()
  50. imgs = torch.randn(1, 3, 128, 128)
  51. feat = model(imgs)
  52. assert feat[0].shape == torch.Size((1, 16, 64, 64))
  53. assert feat[1].shape == torch.Size((1, 32, 32, 32))
  54. assert feat[2].shape == torch.Size((1, 64, 16, 16))
  55. assert feat[3].shape == torch.Size((1, 128, 8, 8))
  56. assert feat[4].shape == torch.Size((1, 192, 4, 4))
  57. assert feat[5].shape == torch.Size((1, 256, 2, 2))
  58. # Test CSPDarknet forward with dict(type='ReLU')
  59. model = CSPDarknet(
  60. widen_factor=0.125, act_cfg=dict(type='ReLU'), out_indices=range(0, 5))
  61. model.train()
  62. imgs = torch.randn(1, 3, 64, 64)
  63. feat = model(imgs)
  64. assert len(feat) == 5
  65. assert feat[0].shape == torch.Size((1, 8, 32, 32))
  66. assert feat[1].shape == torch.Size((1, 16, 16, 16))
  67. assert feat[2].shape == torch.Size((1, 32, 8, 8))
  68. assert feat[3].shape == torch.Size((1, 64, 4, 4))
  69. assert feat[4].shape == torch.Size((1, 128, 2, 2))
  70. # Test CSPDarknet with BatchNorm forward
  71. model = CSPDarknet(widen_factor=0.125, out_indices=range(0, 5))
  72. for m in model.modules():
  73. if is_norm(m):
  74. assert isinstance(m, _BatchNorm)
  75. model.train()
  76. imgs = torch.randn(1, 3, 64, 64)
  77. feat = model(imgs)
  78. assert len(feat) == 5
  79. assert feat[0].shape == torch.Size((1, 8, 32, 32))
  80. assert feat[1].shape == torch.Size((1, 16, 16, 16))
  81. assert feat[2].shape == torch.Size((1, 32, 8, 8))
  82. assert feat[3].shape == torch.Size((1, 64, 4, 4))
  83. assert feat[4].shape == torch.Size((1, 128, 2, 2))
  84. # Test CSPDarknet with custom arch forward
  85. arch_ovewrite = [[32, 56, 3, True, False], [56, 224, 2, True, False],
  86. [224, 512, 1, True, False]]
  87. model = CSPDarknet(
  88. arch_ovewrite=arch_ovewrite,
  89. widen_factor=0.25,
  90. out_indices=(0, 1, 2, 3))
  91. model.train()
  92. imgs = torch.randn(1, 3, 32, 32)
  93. feat = model(imgs)
  94. assert len(feat) == 4
  95. assert feat[0].shape == torch.Size((1, 8, 16, 16))
  96. assert feat[1].shape == torch.Size((1, 14, 8, 8))
  97. assert feat[2].shape == torch.Size((1, 56, 4, 4))
  98. assert feat[3].shape == torch.Size((1, 128, 2, 2))