12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmpose.models.backbones import HourglassAENet, HourglassNet
class TestHourglass(TestCase):
    """Constructor-validation and forward-shape tests for the Hourglass
    backbones (``HourglassNet`` and ``HourglassAENet``)."""

    def _check_stacked_forward(self, backbone_cls, stack_count, out_channels):
        """Build a backbone, run one forward pass and verify its outputs.

        Args:
            backbone_cls: Backbone class to instantiate; called with
                ``num_stacks=stack_count`` only.
            stack_count (int): Value passed as ``num_stacks``; also the
                expected number of returned feature maps.
            out_channels (int): Expected channel dimension of every
                returned feature map.
        """
        net = backbone_cls(num_stacks=stack_count)
        net.init_weights()
        net.train()

        batch = torch.randn(1, 3, 256, 256)
        outputs = net(batch)

        # One feature map per stack, each of shape (1, C, 64, 64) for the
        # 256x256 input used above.
        self.assertEqual(len(outputs), stack_count)
        expected_shape = torch.Size([1, out_channels, 64, 64])
        for stack_idx in range(stack_count):
            self.assertEqual(outputs[stack_idx].shape, expected_shape)

    def test_hourglass_backbone(self):
        """HourglassNet rejects invalid configs and yields 256-channel maps."""
        rejected_configs = (
            # num_stacks must be larger than 0.
            dict(num_stacks=0),
            # len(stage_channels) must equal len(stage_blocks).
            dict(
                stage_channels=[256, 256, 384, 384, 384],
                stage_blocks=[2, 2, 2, 2, 2, 4]),
            # len(stage_channels) must be larger than downsample_times.
            dict(
                downsample_times=5,
                stage_channels=[256, 256, 384, 384, 384],
                stage_blocks=[2, 2, 2, 2, 2]),
        )
        for bad_kwargs in rejected_configs:
            with self.assertRaises(AssertionError):
                HourglassNet(**bad_kwargs)

        # HourglassNet-52 (single stack), then HourglassNet-104 (two stacks).
        for stack_count in (1, 2):
            self._check_stacked_forward(HourglassNet, stack_count, 256)

    def test_hourglass_ae_backbone(self):
        """HourglassAENet rejects invalid configs and yields 34-channel maps."""
        rejected_configs = (
            # num_stacks must be larger than 0.
            dict(num_stacks=0),
            # len(stage_channels) must be larger than downsample_times.
            dict(
                downsample_times=5,
                stage_channels=[256, 256, 384, 384, 384]),
        )
        for bad_kwargs in rejected_configs:
            with self.assertRaises(AssertionError):
                HourglassAENet(**bad_kwargs)

        # num_stacks=1, then num_stacks=2.
        for stack_count in (1, 2):
            self._check_stacked_forward(HourglassAENet, stack_count, 34)
|