123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import numpy as np
- import torch
- from mmengine.structures import PixelData
- from mmpose.structures import MultilevelPixelData
class TestMultilevelPixelData(TestCase):
    """Unit tests for :class:`MultilevelPixelData`."""

    # (height, width) of every level built by the fixture below. Shared by
    # the fixture and by tests that add new per-level fields, so that the
    # added data always matches the existing levels' spatial sizes.
    LEVEL_SIZES = ((64, 48), (32, 24), (16, 12))

    def get_multi_level_pixel_data(self):
        """Build a fresh 3-level ``MultilevelPixelData`` fixture.

        Level ``i`` holds a 17-channel ``heatmaps`` numpy array and a
        1-channel ``masks`` torch tensor of spatial size
        ``LEVEL_SIZES[i]``; ``num_keypoints=17`` is stored as metainfo.
        """
        metainfo = dict(num_keypoints=17)
        heatmaps = [np.random.rand(17, h, w) for h, w in self.LEVEL_SIZES]
        masks = [torch.rand(1, h, w) for h, w in self.LEVEL_SIZES]
        return MultilevelPixelData(
            metainfo=metainfo, heatmaps=heatmaps, masks=masks)

    def test_init(self):
        """Metainfo, level count, per-level shapes and level type."""
        data = self.get_multi_level_pixel_data()
        self.assertIn('num_keypoints', data)
        self.assertEqual(data.nlevel, 3)
        self.assertEqual(data.shape, self.LEVEL_SIZES)
        self.assertIsInstance(data[0], PixelData)

    def test_setter(self):
        """Field assignment plus tensor/device/numpy conversions."""
        # test `set_field`: assigning a per-level sequence adds a new data
        # field whose i-th element is stored on level i. Sizes must agree
        # with the existing levels, otherwise a level would end up holding
        # fields of different spatial sizes.
        data = self.get_multi_level_pixel_data()
        offset_maps = [torch.rand(2, h, w) for h, w in self.LEVEL_SIZES]
        data.offset_maps = offset_maps
        self.assertIn('offset_maps', data)
        for level, (h, w) in enumerate(self.LEVEL_SIZES):
            self.assertEqual(
                tuple(data[level].offset_maps.shape), (2, h, w))

        # test `to_tensor`: numpy fields become torch tensors
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        data = data.to_tensor()
        self.assertIsInstance(data[0].heatmaps, torch.Tensor)

        # test `cpu`: numpy fields are left as-is, tensors stay on CPU
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        self.assertIsInstance(data[0].masks, torch.Tensor)
        self.assertEqual(data[0].masks.device.type, 'cpu')
        data = data.cpu()
        self.assertIsInstance(data[0].heatmaps, np.ndarray)
        self.assertEqual(data[0].masks.device.type, 'cpu')

        # test `to`
        data = self.get_multi_level_pixel_data()
        self.assertEqual(data[0].masks.device.type, 'cpu')
        data = data.to('cpu')
        self.assertEqual(data[0].masks.device.type, 'cpu')

        # test `numpy`: tensor fields become numpy arrays
        data = self.get_multi_level_pixel_data()
        self.assertIsInstance(data[0].masks, torch.Tensor)
        data = data.numpy()
        self.assertIsInstance(data[0].masks, np.ndarray)

    def test_deleter(self):
        """Deleting a data field removes it from the container."""
        data = self.get_multi_level_pixel_data()
        for key in ['heatmaps', 'masks']:
            self.assertIn(key, data)
            # delattr goes through the same __delattr__ path as `del`
            # without compiling and executing a source string.
            delattr(data, key)
            self.assertNotIn(key, data)
|