test_hourglass.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.backbones.hourglass import HourglassNet
  5. def test_hourglass_backbone():
  6. with pytest.raises(AssertionError):
  7. # HourglassNet's num_stacks should larger than 0
  8. HourglassNet(num_stacks=0)
  9. with pytest.raises(AssertionError):
  10. # len(stage_channels) should equal len(stage_blocks)
  11. HourglassNet(
  12. stage_channels=[256, 256, 384, 384, 384],
  13. stage_blocks=[2, 2, 2, 2, 2, 4])
  14. with pytest.raises(AssertionError):
  15. # len(stage_channels) should lagrer than downsample_times
  16. HourglassNet(
  17. downsample_times=5,
  18. stage_channels=[256, 256, 384, 384, 384],
  19. stage_blocks=[2, 2, 2, 2, 2])
  20. # Test HourglassNet-52
  21. model = HourglassNet(
  22. num_stacks=1,
  23. stage_channels=(64, 64, 96, 96, 96, 128),
  24. feat_channel=64)
  25. model.train()
  26. imgs = torch.randn(1, 3, 256, 256)
  27. feat = model(imgs)
  28. assert len(feat) == 1
  29. assert feat[0].shape == torch.Size([1, 64, 64, 64])
  30. # Test HourglassNet-104
  31. model = HourglassNet(
  32. num_stacks=2,
  33. stage_channels=(64, 64, 96, 96, 96, 128),
  34. feat_channel=64)
  35. model.train()
  36. imgs = torch.randn(1, 3, 256, 256)
  37. feat = model(imgs)
  38. assert len(feat) == 2
  39. assert feat[0].shape == torch.Size([1, 64, 64, 64])
  40. assert feat[1].shape == torch.Size([1, 64, 64, 64])