test_inference.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from pathlib import Path
  4. from tempfile import TemporaryDirectory
  5. from unittest import TestCase
  6. import numpy as np
  7. import torch
  8. from mmcv.image import imread, imwrite
  9. from mmengine.utils import is_list_of
  10. from parameterized import parameterized
  11. from mmpose.apis import inference_bottomup, inference_topdown, init_model
  12. from mmpose.structures import PoseDataSample
  13. from mmpose.testing._utils import _rand_bboxes, get_config_file, get_repo_dir
  14. from mmpose.utils import register_all_modules
  15. class TestInference(TestCase):
  16. def setUp(self) -> None:
  17. register_all_modules()
  18. @parameterized.expand([(('configs/body_2d_keypoint/topdown_heatmap/coco/'
  19. 'td-hm_hrnet-w32_8xb64-210e_coco-256x192.py'),
  20. ('cpu', 'cuda'))])
  21. def test_init_model(self, config, devices):
  22. config_file = get_config_file(config)
  23. for device in devices:
  24. if device == 'cuda' and not torch.cuda.is_available():
  25. # Skip the test if cuda is required but unavailable
  26. continue
  27. # test init_model with str path
  28. _ = init_model(config_file, device=device)
  29. # test init_model with :obj:`Path`
  30. _ = init_model(Path(config_file), device=device)
  31. # test init_detector with undesirable type
  32. with self.assertRaisesRegex(
  33. TypeError, 'config must be a filename or Config object'):
  34. config_list = [config_file]
  35. _ = init_model(config_list)
  36. @parameterized.expand([(('configs/body_2d_keypoint/topdown_heatmap/coco/'
  37. 'td-hm_hrnet-w32_8xb64-210e_coco-256x192.py'),
  38. ('cpu', 'cuda'))])
  39. def test_inference_topdown(self, config, devices):
  40. project_dir = osp.abspath(osp.dirname(osp.dirname(__file__)))
  41. project_dir = osp.join(project_dir, '..')
  42. config_file = osp.join(project_dir, config)
  43. rng = np.random.RandomState(0)
  44. img_w = img_h = 100
  45. img = rng.randint(0, 255, (img_h, img_w, 3), dtype=np.uint8)
  46. bboxes = _rand_bboxes(rng, 2, img_w, img_h)
  47. for device in devices:
  48. if device == 'cuda' and not torch.cuda.is_available():
  49. # Skip the test if cuda is required but unavailable
  50. continue
  51. model = init_model(config_file, device=device)
  52. # test inference with bboxes
  53. results = inference_topdown(model, img, bboxes, bbox_format='xywh')
  54. self.assertTrue(is_list_of(results, PoseDataSample))
  55. self.assertEqual(len(results), 2)
  56. self.assertTrue(results[0].pred_instances.keypoints.shape,
  57. (1, 17, 2))
  58. # test inference without bbox
  59. results = inference_topdown(model, img)
  60. self.assertTrue(is_list_of(results, PoseDataSample))
  61. self.assertEqual(len(results), 1)
  62. self.assertTrue(results[0].pred_instances.keypoints.shape,
  63. (1, 17, 2))
  64. # test inference from image file
  65. with TemporaryDirectory() as tmp_dir:
  66. img_path = osp.join(tmp_dir, 'img.jpg')
  67. imwrite(img, img_path)
  68. results = inference_topdown(model, img_path)
  69. self.assertTrue(is_list_of(results, PoseDataSample))
  70. self.assertEqual(len(results), 1)
  71. self.assertTrue(results[0].pred_instances.keypoints.shape,
  72. (1, 17, 2))
  73. @parameterized.expand([(('configs/body_2d_keypoint/'
  74. 'associative_embedding/coco/'
  75. 'ae_hrnet-w32_8xb24-300e_coco-512x512.py'),
  76. ('cpu', 'cuda'))])
  77. def test_inference_bottomup(self, config, devices):
  78. config_file = get_config_file(config)
  79. img = osp.join(get_repo_dir(), 'tests/data/coco/000000000785.jpg')
  80. for device in devices:
  81. if device == 'cuda' and not torch.cuda.is_available():
  82. # Skip the test if cuda is required but unavailable
  83. continue
  84. model = init_model(config_file, device=device)
  85. # test inference from image
  86. results = inference_bottomup(model, img=imread(img))
  87. self.assertTrue(is_list_of(results, PoseDataSample))
  88. self.assertEqual(len(results), 1)
  89. self.assertTrue(results[0].pred_instances.keypoints.shape,
  90. (1, 17, 2))
  91. # test inference from file
  92. results = inference_bottomup(model, img=img)
  93. self.assertTrue(is_list_of(results, PoseDataSample))
  94. self.assertEqual(len(results), 1)
  95. self.assertTrue(results[0].pred_instances.keypoints.shape,
  96. (1, 17, 2))