1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmpose.models.backbones import RSN
class TestRSN(TestCase):
    """Unit tests for the RSN backbone."""

    def test_rsn_backbone(self):
        """Check RSN's constructor validation and the shapes of its outputs."""
        # Each config violates one constructor constraint and must be
        # rejected with an AssertionError:
        #   - num_stages must be greater than 0
        #   - num_steps must be greater than 1
        #   - num_units must be greater than 1
        #   - len(num_blocks) must equal num_units
        invalid_configs = (
            dict(num_stages=0),
            dict(num_steps=1),
            dict(num_units=1),
            dict(num_units=2, num_blocks=[2, 2, 2]),
        )
        for bad_cfg in invalid_configs:
            with self.assertRaises(AssertionError):
                RSN(**bad_cfg)

        # Build a valid network with two stages and two units.
        net = RSN(num_stages=2, num_units=2, num_blocks=[2, 2])
        net.init_weights()
        net.train()

        inputs = torch.randn(1, 3, 511, 511)
        outputs = net(inputs)

        # The output is expected to be a 2 x 2 nested sequence of feature
        # maps (presumably one outer entry per stage and one inner entry
        # per unit); within each inner sequence the shapes are as below.
        expected_shapes = [
            torch.Size([1, 256, 64, 64]),
            torch.Size([1, 256, 128, 128]),
        ]
        self.assertEqual(len(outputs), 2)
        for stage_out in outputs:
            self.assertEqual(len(stage_out), 2)
        for stage_out in outputs:
            for feat_map, shape in zip(stage_out, expected_shapes):
                self.assertEqual(feat_map.shape, shape)
|