test_hourglass.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import HourglassAENet, HourglassNet
  5. class TestHourglass(TestCase):
  6. def test_hourglass_backbone(self):
  7. with self.assertRaises(AssertionError):
  8. # HourglassNet's num_stacks should larger than 0
  9. HourglassNet(num_stacks=0)
  10. with self.assertRaises(AssertionError):
  11. # len(stage_channels) should equal len(stage_blocks)
  12. HourglassNet(
  13. stage_channels=[256, 256, 384, 384, 384],
  14. stage_blocks=[2, 2, 2, 2, 2, 4])
  15. with self.assertRaises(AssertionError):
  16. # len(stage_channels) should larger than downsample_times
  17. HourglassNet(
  18. downsample_times=5,
  19. stage_channels=[256, 256, 384, 384, 384],
  20. stage_blocks=[2, 2, 2, 2, 2])
  21. # Test HourglassNet-52
  22. model = HourglassNet(num_stacks=1)
  23. model.init_weights()
  24. model.train()
  25. imgs = torch.randn(1, 3, 256, 256)
  26. feat = model(imgs)
  27. self.assertEqual(len(feat), 1)
  28. self.assertEqual(feat[0].shape, torch.Size([1, 256, 64, 64]))
  29. # Test HourglassNet-104
  30. model = HourglassNet(num_stacks=2)
  31. model.init_weights()
  32. model.train()
  33. imgs = torch.randn(1, 3, 256, 256)
  34. feat = model(imgs)
  35. self.assertEqual(len(feat), 2)
  36. self.assertEqual(feat[0].shape, torch.Size([1, 256, 64, 64]))
  37. self.assertEqual(feat[1].shape, torch.Size([1, 256, 64, 64]))
  38. def test_hourglass_ae_backbone(self):
  39. with self.assertRaises(AssertionError):
  40. # HourglassAENet's num_stacks should larger than 0
  41. HourglassAENet(num_stacks=0)
  42. with self.assertRaises(AssertionError):
  43. # len(stage_channels) should larger than downsample_times
  44. HourglassAENet(
  45. downsample_times=5, stage_channels=[256, 256, 384, 384, 384])
  46. # num_stack=1
  47. model = HourglassAENet(num_stacks=1)
  48. model.init_weights()
  49. model.train()
  50. imgs = torch.randn(1, 3, 256, 256)
  51. feat = model(imgs)
  52. self.assertEqual(len(feat), 1)
  53. self.assertEqual(feat[0].shape, torch.Size([1, 34, 64, 64]))
  54. # num_stack=2
  55. model = HourglassAENet(num_stacks=2)
  56. model.init_weights()
  57. model.train()
  58. imgs = torch.randn(1, 3, 256, 256)
  59. feat = model(imgs)
  60. self.assertEqual(len(feat), 2)
  61. self.assertEqual(feat[0].shape, torch.Size([1, 34, 64, 64]))
  62. self.assertEqual(feat[1].shape, torch.Size([1, 34, 64, 64]))