test_multilevel_pixel_data.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.structures import PixelData
  6. from mmpose.structures import MultilevelPixelData
  7. class TestMultilevelPixelData(TestCase):
  8. def get_multi_level_pixel_data(self):
  9. metainfo = dict(num_keypoints=17)
  10. sizes = [(64, 48), (32, 24), (16, 12)]
  11. heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
  12. masks = [torch.rand(1, h, w) for h, w in sizes]
  13. data = MultilevelPixelData(
  14. metainfo=metainfo, heatmaps=heatmaps, masks=masks)
  15. return data
  16. def test_init(self):
  17. data = self.get_multi_level_pixel_data()
  18. self.assertIn('num_keypoints', data)
  19. self.assertTrue(data.nlevel == 3)
  20. self.assertTrue(data.shape == ((64, 48), (32, 24), (16, 12)))
  21. self.assertTrue(isinstance(data[0], PixelData))
  22. def test_setter(self):
  23. # test `set_field`
  24. data = self.get_multi_level_pixel_data()
  25. sizes = [(64, 48), (32, 24), (16, 8)]
  26. offset_maps = [torch.rand(2, h, w) for h, w in sizes]
  27. data.offset_maps = offset_maps
  28. # test `to_tensor`
  29. data = self.get_multi_level_pixel_data()
  30. self.assertTrue(isinstance(data[0].heatmaps, np.ndarray))
  31. data = data.to_tensor()
  32. self.assertTrue(isinstance(data[0].heatmaps, torch.Tensor))
  33. # test `cpu`
  34. data = self.get_multi_level_pixel_data()
  35. self.assertTrue(isinstance(data[0].heatmaps, np.ndarray))
  36. self.assertTrue(isinstance(data[0].masks, torch.Tensor))
  37. self.assertTrue(data[0].masks.device.type == 'cpu')
  38. data = data.cpu()
  39. self.assertTrue(isinstance(data[0].heatmaps, np.ndarray))
  40. self.assertTrue(data[0].masks.device.type == 'cpu')
  41. # test `to`
  42. data = self.get_multi_level_pixel_data()
  43. self.assertTrue(data[0].masks.device.type == 'cpu')
  44. data = data.to('cpu')
  45. self.assertTrue(data[0].masks.device.type == 'cpu')
  46. # test `numpy`
  47. data = self.get_multi_level_pixel_data()
  48. self.assertTrue(isinstance(data[0].masks, torch.Tensor))
  49. data = data.numpy()
  50. self.assertTrue(isinstance(data[0].masks, np.ndarray))
  51. def test_deleter(self):
  52. data = self.get_multi_level_pixel_data()
  53. for key in ['heatmaps', 'masks']:
  54. self.assertIn(key, data)
  55. exec(f'del data.{key}')
  56. self.assertNotIn(key, data)