test_det_inferencer.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import tempfile
  4. from unittest import TestCase, mock
  5. from unittest.mock import Mock, patch
  6. import mmcv
  7. import mmengine
  8. import numpy as np
  9. import torch
  10. from mmengine.structures import InstanceData
  11. from mmengine.utils import is_list_of
  12. from parameterized import parameterized
  13. from mmdet.apis import DetInferencer
  14. from mmdet.evaluation.functional import get_classes
  15. from mmdet.structures import DetDataSample
  class TestDetInferencer(TestCase):
      """Tests for ``mmdet.apis.DetInferencer``.

      ``mmengine.infer.infer._load_checkpoint`` is patched to return ``None``
      whenever an inferencer is built, so no pretrained checkpoint is loaded.
      The tests therefore compare results obtained through different input
      forms (path vs. ndarray, different batch sizes, dumped vs. returned
      predictions) against each other rather than against fixed values.
      """

      # Model names shared by the parameterized tests.
      _MODELS = [
          'rtmdet-t', 'mask-rcnn_r50_fpn_1x_coco',
          'panoptic_fpn_r50_fpn_1x_coco'
      ]
      # The model whose visualizer metadata must be replaced after building.
      _PANOPTIC_MODEL = 'panoptic_fpn_r50_fpn_1x_coco'

      def _build_inferencer(self, model):
          """Build a ``DetInferencer`` for ``model`` without loading weights.

          Args:
              model (str): A model name from the metafile.

          Returns:
              DetInferencer: The built inferencer. For the panoptic model its
              visualizer ``dataset_meta`` is replaced with the COCO panoptic
              classes and a random palette.
          """
          with patch('mmengine.infer.infer._load_checkpoint',
                     Mock(return_value=None)):
              inferencer = DetInferencer(model)
          # In the case of not loading the pretrained weight, the category
          # defaults to COCO 80, so it needs to be replaced.
          if model == self._PANOPTIC_MODEL:
              inferencer.visualizer.dataset_meta = {
                  'classes': get_classes('coco_panoptic'),
                  'palette': 'random'
              }
          return inferencer

      @mock.patch('mmengine.infer.infer._load_checkpoint', return_value=None)
      def test_init(self, mock_load_checkpoint):
          """Build inferencers from a metafile alias and from a config path."""
          # init from metafile
          DetInferencer('rtmdet-t')
          # init from cfg
          DetInferencer('configs/yolox/yolox_tiny_8xb8-300e_coco.py')

      def assert_predictions_equal(self, preds1, preds2):
          """Assert that two lists of prediction dicts agree.

          Only keys present in each element of ``preds1`` are checked:
          ``bboxes`` and ``scores`` with ``np.allclose(..., rtol=0.1)``,
          ``labels`` with ``np.allclose`` defaults, and ``panoptic_seg_path``
          for exact equality.

          Args:
              preds1 (list[dict]): Reference predictions.
              preds2 (list[dict]): Predictions to compare against ``preds1``.
          """
          # zip() stops at the shorter list, so without this check a missing
          # or extra prediction would go unnoticed.
          self.assertEqual(len(preds1), len(preds2))
          for pred1, pred2 in zip(preds1, preds2):
              if 'bboxes' in pred1:
                  self.assertTrue(
                      np.allclose(pred1['bboxes'], pred2['bboxes'], rtol=0.1))
              if 'scores' in pred1:
                  self.assertTrue(
                      np.allclose(pred1['scores'], pred2['scores'], rtol=0.1))
              if 'labels' in pred1:
                  self.assertTrue(np.allclose(pred1['labels'], pred2['labels']))
              if 'panoptic_seg_path' in pred1:
                  self.assertEqual(pred1['panoptic_seg_path'],
                                   pred2['panoptic_seg_path'])

      @parameterized.expand(_MODELS)
      def test_call(self, model):
          """Check that paths, ndarrays, lists and directories agree."""
          inferencer = self._build_inferencer(model)

          # single img
          img_path = 'tests/data/color.jpg'
          res_path = inferencer(img_path, return_vis=True)
          # ndarray
          img = mmcv.imread(img_path)
          res_ndarray = inferencer(img, return_vis=True)
          self.assert_predictions_equal(res_path['predictions'],
                                        res_ndarray['predictions'])
          self.assertIn('visualization', res_path)
          self.assertIn('visualization', res_ndarray)

          # multiple images
          img_paths = ['tests/data/color.jpg', 'tests/data/gray.jpg']
          res_path = inferencer(img_paths, return_vis=True)
          # list of ndarray
          imgs = [mmcv.imread(p) for p in img_paths]
          res_ndarray = inferencer(imgs, return_vis=True)
          self.assert_predictions_equal(res_path['predictions'],
                                        res_ndarray['predictions'])
          self.assertIn('visualization', res_path)
          self.assertIn('visualization', res_ndarray)

          # img dir, test different batch sizes
          img_dir = 'tests/data/VOCdevkit/VOC2007/JPEGImages/'
          res_bs1 = inferencer(img_dir, batch_size=1, return_vis=True)
          res_bs3 = inferencer(img_dir, batch_size=3, return_vis=True)
          self.assert_predictions_equal(res_bs1['predictions'],
                                        res_bs3['predictions'])
          # There is a jitter operation when the mask is drawn,
          # so it cannot be asserted.
          if model == 'rtmdet-t':
              for res_bs1_vis, res_bs3_vis in zip(res_bs1['visualization'],
                                                  res_bs3['visualization']):
                  self.assertTrue(np.allclose(res_bs1_vis, res_bs3_vis))

      @parameterized.expand(_MODELS)
      def test_visualize(self, model):
          """Check that a visualization is written to ``out_dir/vis``."""
          img_paths = ['tests/data/color.jpg', 'tests/data/gray.jpg']
          inferencer = self._build_inferencer(model)
          with tempfile.TemporaryDirectory() as tmp_dir:
              inferencer(img_paths, out_dir=tmp_dir)
              for img_name in ['color.jpg', 'gray.jpg']:
                  self.assertTrue(
                      osp.exists(osp.join(tmp_dir, 'vis', img_name)))

      @parameterized.expand(_MODELS)
      def test_postprocess(self, model):
          """Check datasample output and that dumped predictions match."""
          img_path = 'tests/data/color.jpg'
          inferencer = self._build_inferencer(model)

          # return_datasample
          res = inferencer(img_path, return_datasample=True)
          self.assertTrue(is_list_of(res['predictions'], DetDataSample))

          with tempfile.TemporaryDirectory() as tmp_dir:
              res = inferencer(img_path, out_dir=tmp_dir, no_save_pred=False)
              dumped_res = mmengine.load(
                  osp.join(tmp_dir, 'preds', 'color.json'))
              self.assertEqual(res['predictions'][0], dumped_res)

      @mock.patch('mmengine.infer.infer._load_checkpoint', return_value=None)
      def test_pred2dict(self, mock_load_checkpoint):
          """Check ``pred2dict`` converts array/tensor fields to lists."""
          data_sample = DetDataSample()
          data_sample.pred_instances = InstanceData()
          data_sample.pred_instances.bboxes = np.array([[0, 0, 1, 1]])
          data_sample.pred_instances.labels = np.array([0])
          data_sample.pred_instances.scores = torch.FloatTensor([0.9])
          res = DetInferencer('rtmdet-t').pred2dict(data_sample)
          self.assertListAlmostEqual(res['bboxes'], [[0, 0, 1, 1]])
          self.assertListAlmostEqual(res['labels'], [0])
          self.assertListAlmostEqual(res['scores'], [0.9])

      def assertListAlmostEqual(self, list1, list2, places=7):
          """Recursively assert that two (nested) lists are almost equal.

          Elements of ``list1`` that are lists are compared recursively.
          All other elements are compared with ``assertAlmostEqual`` to
          ``places`` decimal places.

          Args:
              list1 (list): Actual values.
              list2 (list): Expected values, same length and nesting.
              places (int): Decimal places passed to ``assertAlmostEqual``.
          """
          # Without this, a longer list2 would pass silently and a shorter
          # one would raise IndexError instead of a clear assertion failure.
          self.assertEqual(len(list1), len(list2))
          for item1, item2 in zip(list1, list2):
              if isinstance(item1, list):
                  self.assertListAlmostEqual(item1, item2, places=places)
              else:
                  self.assertAlmostEqual(item1, item2, places=places)