test_nms.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. from mmpose.evaluation.functional.nms import nearby_joints_nms
  5. class TestNearbyJointsNMS(TestCase):
  6. def test_nearby_joints_nms(self):
  7. kpts_db = []
  8. keep_pose_inds = nearby_joints_nms(
  9. kpts_db, 0.05, score_per_joint=True, max_dets=1)
  10. self.assertEqual(len(keep_pose_inds), 0)
  11. kpts_db = []
  12. for _ in range(5):
  13. kpts_db.append(
  14. dict(keypoints=np.random.rand(3, 2), score=np.random.rand(3)))
  15. keep_pose_inds = nearby_joints_nms(
  16. kpts_db, 0.05, score_per_joint=True, max_dets=1)
  17. self.assertEqual(len(keep_pose_inds), 1)
  18. self.assertLess(keep_pose_inds[0], 5)
  19. kpts_db = []
  20. for _ in range(5):
  21. kpts_db.append(
  22. dict(keypoints=np.random.rand(3, 2), score=np.random.rand()))
  23. keep_pose_inds = nearby_joints_nms(
  24. kpts_db, 0.05, num_nearby_joints_thr=2)
  25. self.assertLessEqual(len(keep_pose_inds), 5)
  26. self.assertGreater(len(keep_pose_inds), 0)
  27. with self.assertRaises(AssertionError):
  28. _ = nearby_joints_nms(kpts_db, 0, num_nearby_joints_thr=2)
  29. with self.assertRaises(AssertionError):
  30. _ = nearby_joints_nms(kpts_db, 0.05, num_nearby_joints_thr=3)