test_v2v_net.py 470 B

1234567891011121314151617
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.backbones import V2VNet
  5. class TestV2Vnet(TestCase):
  6. def test_v2v_net(self):
  7. """Test V2VNet."""
  8. model = V2VNet(input_channels=17, output_channels=15)
  9. input = torch.randn(2, 17, 32, 32, 32)
  10. output = model(input)
  11. self.assertIsInstance(output, tuple)
  12. self.assertEqual(output[-1].shape, (2, 15, 32, 32, 32))