# test_keypoint_eval.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal
  4. from mmpose.evaluation.functional import (keypoint_auc, keypoint_epe,
  5. keypoint_nme, keypoint_pck_accuracy,
  6. multilabel_classification_accuracy,
  7. pose_pck_accuracy)
  8. def test_keypoint_pck_accuracy():
  9. output = np.zeros((2, 5, 2))
  10. target = np.zeros((2, 5, 2))
  11. mask = np.array([[True, True, False, True, True],
  12. [True, True, False, True, True]])
  13. thr = np.full((2, 2), 10, dtype=np.float32)
  14. # first channel
  15. output[0, 0] = [10, 0]
  16. target[0, 0] = [10, 0]
  17. # second channel
  18. output[0, 1] = [20, 20]
  19. target[0, 1] = [10, 10]
  20. # third channel
  21. output[0, 2] = [0, 0]
  22. target[0, 2] = [-1, 0]
  23. # fourth channel
  24. output[0, 3] = [30, 30]
  25. target[0, 3] = [30, 30]
  26. # fifth channel
  27. output[0, 4] = [0, 10]
  28. target[0, 4] = [0, 10]
  29. acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, thr)
  30. assert_array_almost_equal(acc, np.array([1, 0.5, -1, 1, 1]), decimal=4)
  31. assert abs(avg_acc - 0.875) < 1e-4
  32. assert abs(cnt - 4) < 1e-4
  33. acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5,
  34. np.zeros((2, 2)))
  35. assert_array_almost_equal(acc, np.array([-1, -1, -1, -1, -1]), decimal=4)
  36. assert abs(avg_acc) < 1e-4
  37. assert abs(cnt) < 1e-4
  38. acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5,
  39. np.array([[0, 0], [10, 10]]))
  40. assert_array_almost_equal(acc, np.array([1, 1, -1, 1, 1]), decimal=4)
  41. assert abs(avg_acc - 1) < 1e-4
  42. assert abs(cnt - 4) < 1e-4
  43. def test_keypoint_auc():
  44. output = np.zeros((1, 5, 2))
  45. target = np.zeros((1, 5, 2))
  46. mask = np.array([[True, True, False, True, True]])
  47. # first channel
  48. output[0, 0] = [10, 4]
  49. target[0, 0] = [10, 0]
  50. # second channel
  51. output[0, 1] = [10, 18]
  52. target[0, 1] = [10, 10]
  53. # third channel
  54. output[0, 2] = [0, 0]
  55. target[0, 2] = [0, -1]
  56. # fourth channel
  57. output[0, 3] = [40, 40]
  58. target[0, 3] = [30, 30]
  59. # fifth channel
  60. output[0, 4] = [20, 10]
  61. target[0, 4] = [0, 10]
  62. auc = keypoint_auc(output, target, mask, 20, 4)
  63. assert abs(auc - 0.375) < 1e-4
  64. def test_keypoint_epe():
  65. output = np.zeros((1, 5, 2))
  66. target = np.zeros((1, 5, 2))
  67. mask = np.array([[True, True, False, True, True]])
  68. # first channel
  69. output[0, 0] = [10, 4]
  70. target[0, 0] = [10, 0]
  71. # second channel
  72. output[0, 1] = [10, 18]
  73. target[0, 1] = [10, 10]
  74. # third channel
  75. output[0, 2] = [0, 0]
  76. target[0, 2] = [-1, -1]
  77. # fourth channel
  78. output[0, 3] = [40, 40]
  79. target[0, 3] = [30, 30]
  80. # fifth channel
  81. output[0, 4] = [20, 10]
  82. target[0, 4] = [0, 10]
  83. epe = keypoint_epe(output, target, mask)
  84. assert abs(epe - 11.5355339) < 1e-4
  85. def test_keypoint_nme():
  86. output = np.zeros((1, 5, 2))
  87. target = np.zeros((1, 5, 2))
  88. mask = np.array([[True, True, False, True, True]])
  89. # first channel
  90. output[0, 0] = [10, 4]
  91. target[0, 0] = [10, 0]
  92. # second channel
  93. output[0, 1] = [10, 18]
  94. target[0, 1] = [10, 10]
  95. # third channel
  96. output[0, 2] = [0, 0]
  97. target[0, 2] = [-1, -1]
  98. # fourth channel
  99. output[0, 3] = [40, 40]
  100. target[0, 3] = [30, 30]
  101. # fifth channel
  102. output[0, 4] = [20, 10]
  103. target[0, 4] = [0, 10]
  104. normalize_factor = np.ones((output.shape[0], output.shape[2]))
  105. nme = keypoint_nme(output, target, mask, normalize_factor)
  106. assert abs(nme - 11.5355339) < 1e-4
  107. def test_pose_pck_accuracy():
  108. output = np.zeros((1, 5, 64, 64), dtype=np.float32)
  109. target = np.zeros((1, 5, 64, 64), dtype=np.float32)
  110. mask = np.array([[True, True, False, False, False]])
  111. # first channel
  112. output[0, 0, 20, 20] = 1
  113. target[0, 0, 10, 10] = 1
  114. # second channel
  115. output[0, 1, 30, 30] = 1
  116. target[0, 1, 30, 30] = 1
  117. acc, avg_acc, cnt = pose_pck_accuracy(output, target, mask)
  118. assert_array_almost_equal(acc, np.array([0, 1, -1, -1, -1]), decimal=4)
  119. assert abs(avg_acc - 0.5) < 1e-4
  120. assert abs(cnt - 2) < 1e-4
  121. def test_multilabel_classification_accuracy():
  122. output = np.array([[0.7, 0.8, 0.4], [0.8, 0.1, 0.1]])
  123. target = np.array([[1, 0, 0], [1, 0, 1]])
  124. mask = np.array([[True, True, True], [True, True, True]])
  125. thr = 0.5
  126. acc = multilabel_classification_accuracy(output, target, mask, thr)
  127. assert acc == 0
  128. output = np.array([[0.7, 0.2, 0.4], [0.8, 0.1, 0.9]])
  129. thr = 0.5
  130. acc = multilabel_classification_accuracy(output, target, mask, thr)
  131. assert acc == 1
  132. thr = 0.3
  133. acc = multilabel_classification_accuracy(output, target, mask, thr)
  134. assert acc == 0.5
  135. mask = np.array([[True, True, False], [True, True, True]])
  136. acc = multilabel_classification_accuracy(output, target, mask, thr)
  137. assert acc == 1