test_inference.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import os
  2. from pathlib import Path
  3. import numpy as np
  4. import pytest
  5. import torch
  6. from mmdet.apis import inference_detector, init_detector
  7. from mmdet.structures import DetDataSample
  8. from mmdet.utils import register_all_modules
# Called once at import time, before any test runs. Presumably this
# registers mmdet's modules so that `init_detector` can build models from the
# config files used below -- confirm against mmdet.utils.register_all_modules.
# TODO: the call is kept at module level pending a fix for the error raised
# when register_all_modules() is called multiple times.
register_all_modules()
  11. @pytest.mark.parametrize('config,devices',
  12. [('configs/retinanet/retinanet_r18_fpn_1x_coco.py',
  13. ('cpu', 'cuda'))])
  14. def test_init_detector(config, devices):
  15. assert all([device in ['cpu', 'cuda'] for device in devices])
  16. project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  17. project_dir = os.path.join(project_dir, '..')
  18. config_file = os.path.join(project_dir, config)
  19. # test init_detector with config_file: str and cfg_options
  20. cfg_options = dict(
  21. model=dict(
  22. backbone=dict(
  23. depth=18,
  24. init_cfg=dict(
  25. type='Pretrained', checkpoint='torchvision://resnet18'))))
  26. for device in devices:
  27. if device == 'cuda' and not torch.cuda.is_available():
  28. pytest.skip('test requires GPU and torch+cuda')
  29. model = init_detector(
  30. config_file, device=device, cfg_options=cfg_options)
  31. # test init_detector with :obj:`Path`
  32. config_path_object = Path(config_file)
  33. model = init_detector(config_path_object, device=device)
  34. # test init_detector with undesirable type
  35. with pytest.raises(TypeError):
  36. config_list = [config_file]
  37. model = init_detector(config_list) # noqa: F841
  38. @pytest.mark.parametrize('config,devices',
  39. [('configs/retinanet/retinanet_r18_fpn_1x_coco.py',
  40. ('cpu', 'cuda'))])
  41. def test_inference_detector(config, devices):
  42. assert all([device in ['cpu', 'cuda'] for device in devices])
  43. project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  44. project_dir = os.path.join(project_dir, '..')
  45. config_file = os.path.join(project_dir, config)
  46. # test init_detector with config_file: str and cfg_options
  47. rng = np.random.RandomState(0)
  48. img1 = rng.randint(0, 255, (100, 100, 3), dtype=np.uint8)
  49. img2 = rng.randint(0, 255, (100, 100, 3), dtype=np.uint8)
  50. for device in devices:
  51. if device == 'cuda' and not torch.cuda.is_available():
  52. pytest.skip('test requires GPU and torch+cuda')
  53. model = init_detector(config_file, device=device)
  54. result = inference_detector(model, img1)
  55. assert isinstance(result, DetDataSample)
  56. result = inference_detector(model, [img1, img2])
  57. assert isinstance(result, list) and len(result) == 2