test_mmpose_inferencer.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import os.path as osp
  4. import platform
  5. import unittest
  6. from collections import defaultdict
  7. from tempfile import TemporaryDirectory
  8. from unittest import TestCase
  9. import mmcv
  10. from mmpose.apis.inferencers import MMPoseInferencer
  11. from mmpose.structures import PoseDataSample
  12. class TestMMPoseInferencer(TestCase):
  13. def test_call(self):
  14. try:
  15. from mmdet.apis.det_inferencer import DetInferencer # noqa: F401
  16. except (ImportError, ModuleNotFoundError):
  17. return unittest.skip('mmdet is not installed')
  18. # top-down model
  19. if platform.system().lower() == 'windows':
  20. # the default human pose estimator utilizes rtmdet-m detector
  21. # through alias, which seems not compatible with windows
  22. det_model = 'demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
  23. det_weights = 'https://download.openmmlab.com/mmdetection/v2.0/' \
  24. 'faster_rcnn/faster_rcnn_r50_fpn_1x_coco/' \
  25. 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
  26. else:
  27. det_model, det_weights = None, None
  28. inferencer = MMPoseInferencer(
  29. 'human', det_model=det_model, det_weights=det_weights)
  30. img_path = 'tests/data/coco/000000197388.jpg'
  31. img = mmcv.imread(img_path)
  32. # `inputs` is path to an image
  33. inputs = img_path
  34. results1 = next(inferencer(inputs, return_vis=True))
  35. self.assertIn('visualization', results1)
  36. self.assertSequenceEqual(results1['visualization'][0].shape, img.shape)
  37. self.assertIn('predictions', results1)
  38. self.assertIn('keypoints', results1['predictions'][0][0])
  39. self.assertEqual(len(results1['predictions'][0][0]['keypoints']), 17)
  40. # `inputs` is an image array
  41. inputs = img
  42. results2 = next(inferencer(inputs))
  43. self.assertEqual(
  44. len(results1['predictions'][0]), len(results2['predictions'][0]))
  45. self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
  46. results2['predictions'][0][0]['keypoints'])
  47. results2 = next(inferencer(inputs, return_datasample=True))
  48. self.assertIsInstance(results2['predictions'][0], PoseDataSample)
  49. # `inputs` is path to a directory
  50. inputs = osp.dirname(img_path)
  51. with TemporaryDirectory() as tmp_dir:
  52. # only save visualizations
  53. for res in inferencer(inputs, vis_out_dir=tmp_dir):
  54. pass
  55. self.assertEqual(len(os.listdir(tmp_dir)), 4)
  56. # save both visualizations and predictions
  57. results3 = defaultdict(list)
  58. for res in inferencer(inputs, out_dir=tmp_dir):
  59. for key in res:
  60. results3[key].extend(res[key])
  61. self.assertEqual(len(os.listdir(f'{tmp_dir}/visualizations')), 4)
  62. self.assertEqual(len(os.listdir(f'{tmp_dir}/predictions')), 4)
  63. self.assertEqual(len(results3['predictions']), 4)
  64. self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
  65. results3['predictions'][3][0]['keypoints'])
  66. # `inputs` is path to a video
  67. inputs = 'tests/data/posetrack18/videos/000001_mpiinew_test/' \
  68. '000001_mpiinew_test.mp4'
  69. with TemporaryDirectory() as tmp_dir:
  70. results = defaultdict(list)
  71. for res in inferencer(inputs, out_dir=tmp_dir):
  72. for key in res:
  73. results[key].extend(res[key])
  74. self.assertIn('000001_mpiinew_test.mp4',
  75. os.listdir(f'{tmp_dir}/visualizations'))
  76. self.assertIn('000001_mpiinew_test.json',
  77. os.listdir(f'{tmp_dir}/predictions'))
  78. self.assertTrue(inferencer._video_input)
  79. self.assertIn(len(results['predictions']), (4, 5))