1234567891011121314151617 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmpose.models.backbones import V2VNet
class TestV2Vnet(TestCase):
    """Unit tests for the ``V2VNet`` volumetric backbone."""

    def test_v2v_net(self):
        """Test V2VNet.

        Builds a ``V2VNet`` with 17 input channels and 15 output channels and
        runs a random 5-D batch of shape ``(2, 17, 32, 32, 32)`` through it.
        The test checks that the forward pass returns a non-empty tuple. It
        also checks that the last element has shape ``(2, 15, 32, 32, 32)``:
        the channel count becomes ``output_channels`` and the batch and
        spatial dims are unchanged.
        """
        model = V2VNet(input_channels=17, output_channels=15)
        # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = torch.randn(2, 17, 32, 32, 32)

        # Only output shapes are checked, so skip building the autograd graph.
        with torch.no_grad():
            output = model(inputs)

        self.assertIsInstance(output, tuple)
        # Fail with a clear message instead of an IndexError on output[-1].
        self.assertGreater(len(output), 0)
        self.assertEqual(output[-1].shape, (2, 15, 32, 32, 32))
|