test_udp_heatmap.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.codecs import UDPHeatmap
  5. from mmpose.registry import KEYPOINT_CODECS
  6. class TestUDPHeatmap(TestCase):
  7. def setUp(self) -> None:
  8. # name and configs of all test cases
  9. self.configs = [
  10. (
  11. 'udp gaussian',
  12. dict(
  13. type='UDPHeatmap',
  14. input_size=(192, 256),
  15. heatmap_size=(48, 64),
  16. heatmap_type='gaussian',
  17. ),
  18. ),
  19. (
  20. 'udp combined',
  21. dict(
  22. type='UDPHeatmap',
  23. input_size=(192, 256),
  24. heatmap_size=(48, 64),
  25. heatmap_type='combined'),
  26. ),
  27. ]
  28. # The bbox is usually padded so the keypoint will not be near the
  29. # boundary
  30. keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
  31. keypoints = np.round(keypoints).astype(np.float32)
  32. keypoints_visible = np.ones((1, 17), dtype=np.float32)
  33. self.data = dict(
  34. keypoints=keypoints, keypoints_visible=keypoints_visible)
  35. def test_encode(self):
  36. keypoints = self.data['keypoints']
  37. keypoints_visible = self.data['keypoints_visible']
  38. for name, cfg in self.configs:
  39. codec = KEYPOINT_CODECS.build(cfg)
  40. encoded = codec.encode(keypoints, keypoints_visible)
  41. if codec.heatmap_type == 'combined':
  42. channel_per_kpt = 3
  43. else:
  44. channel_per_kpt = 1
  45. self.assertEqual(encoded['heatmaps'].shape,
  46. (channel_per_kpt * 17, 64, 48),
  47. f'Failed case: "{name}"')
  48. self.assertEqual(encoded['keypoint_weights'].shape,
  49. (1, 17)), f'Failed case: "{name}"'
  50. def test_decode(self):
  51. for name, cfg in self.configs:
  52. codec = KEYPOINT_CODECS.build(cfg)
  53. if codec.heatmap_type == 'combined':
  54. channel_per_kpt = 3
  55. else:
  56. channel_per_kpt = 1
  57. heatmaps = np.random.rand(channel_per_kpt * 17, 64,
  58. 48).astype(np.float32)
  59. keypoints, scores = codec.decode(heatmaps)
  60. self.assertEqual(keypoints.shape, (1, 17, 2),
  61. f'Failed case: "{name}"')
  62. self.assertEqual(scores.shape, (1, 17), f'Failed case: "{name}"')
  63. def test_cicular_verification(self):
  64. keypoints = self.data['keypoints']
  65. keypoints_visible = self.data['keypoints_visible']
  66. for name, cfg in self.configs:
  67. codec = KEYPOINT_CODECS.build(cfg)
  68. encoded = codec.encode(keypoints, keypoints_visible)
  69. _keypoints, _ = codec.decode(encoded['heatmaps'])
  70. self.assertTrue(
  71. np.allclose(keypoints, _keypoints, atol=10.),
  72. f'Failed case: "{name}",{abs(keypoints - _keypoints) < 5.} ')
  73. def test_errors(self):
  74. # invalid heatmap type
  75. with self.assertRaisesRegex(ValueError, 'invalid `heatmap_type`'):
  76. _ = UDPHeatmap(
  77. input_size=(192, 256),
  78. heatmap_size=(48, 64),
  79. heatmap_type='invalid')
  80. # multiple instance
  81. codec = UDPHeatmap(input_size=(192, 256), heatmap_size=(48, 64))
  82. keypoints = np.random.rand(2, 17, 2)
  83. keypoints_visible = np.random.rand(2, 17)
  84. with self.assertRaisesRegex(AssertionError,
  85. 'only support single-instance'):
  86. codec.encode(keypoints, keypoints_visible)