12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
- from mmdet.models.backbones.hourglass import HourglassNet
def test_hourglass_backbone():
    """Check HourglassNet argument validation and forward output shapes.

    Invalid configurations must raise ``AssertionError`` from the
    constructor. The valid 1-stack (HourglassNet-52) and 2-stack
    (HourglassNet-104) settings must return one feature map per stack.
    """
    with pytest.raises(AssertionError):
        # num_stacks must be larger than 0.
        HourglassNet(num_stacks=0)

    with pytest.raises(AssertionError):
        # len(stage_channels) must equal len(stage_blocks) (5 vs. 6 here).
        HourglassNet(
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2, 4])

    with pytest.raises(AssertionError):
        # len(stage_channels) must be larger than downsample_times
        # (5 is not larger than 5 here).
        HourglassNet(
            downsample_times=5,
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2])

    # num_stacks=1 is the HourglassNet-52 setting and num_stacks=2 is the
    # HourglassNet-104 setting. Both use the same reduced channel config,
    # so one loop covers the two previously duplicated checks.
    for num_stacks in (1, 2):
        model = HourglassNet(
            num_stacks=num_stacks,
            stage_channels=(64, 64, 96, 96, 96, 128),
            feat_channel=64)
        model.train()

        imgs = torch.randn(1, 3, 256, 256)
        feat = model(imgs)

        # Expect one output per stack. Each output has feat_channel (64)
        # channels at 1/4 of the 256x256 input resolution.
        assert len(feat) == num_stacks, f'num_stacks={num_stacks}'
        for stack_feat in feat:
            assert stack_feat.shape == torch.Size([1, 64, 64, 64]), \
                f'num_stacks={num_stacks}'
|