# test_local_visualizer.py (4.1 KB)

import os
import tempfile
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmdet.evaluation import INSTANCE_OFFSET
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
  10. def _rand_bboxes(num_boxes, h, w):
  11. cx, cy, bw, bh = torch.rand(num_boxes, 4).T
  12. tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
  13. tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
  14. br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
  15. br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
  16. bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
  17. return bboxes
  18. def _create_panoptic_data(num_boxes, h, w):
  19. sem_seg = np.zeros((h, w), dtype=np.int64) + 2
  20. bboxes = _rand_bboxes(num_boxes, h, w).int()
  21. labels = torch.randint(2, (num_boxes, ))
  22. for i in range(num_boxes):
  23. x, y, w, h = bboxes[i]
  24. sem_seg[y:y + h, x:x + w] = (i + 1) * INSTANCE_OFFSET + labels[i]
  25. return sem_seg[None]
  26. class TestDetLocalVisualizer(TestCase):
  27. def test_add_datasample(self):
  28. h = 12
  29. w = 10
  30. num_class = 3
  31. num_bboxes = 5
  32. out_file = 'out_file.jpg'
  33. image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
  34. # test gt_instances
  35. gt_instances = InstanceData()
  36. gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  37. gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  38. det_data_sample = DetDataSample()
  39. det_data_sample.gt_instances = gt_instances
  40. det_local_visualizer = DetLocalVisualizer()
  41. det_local_visualizer.add_datasample(
  42. 'image', image, det_data_sample, draw_pred=False)
  43. # test out_file
  44. det_local_visualizer.add_datasample(
  45. 'image',
  46. image,
  47. det_data_sample,
  48. draw_pred=False,
  49. out_file=out_file)
  50. assert os.path.exists(out_file)
  51. drawn_img = cv2.imread(out_file)
  52. assert drawn_img.shape == (h, w, 3)
  53. os.remove(out_file)
  54. # test gt_instances and pred_instances
  55. pred_instances = InstanceData()
  56. pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  57. pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  58. pred_instances.scores = torch.rand((num_bboxes, ))
  59. det_data_sample.pred_instances = pred_instances
  60. det_local_visualizer.add_datasample(
  61. 'image', image, det_data_sample, out_file=out_file)
  62. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  63. det_local_visualizer.add_datasample(
  64. 'image', image, det_data_sample, draw_gt=False, out_file=out_file)
  65. self._assert_image_and_shape(out_file, (h, w, 3))
  66. det_local_visualizer.add_datasample(
  67. 'image',
  68. image,
  69. det_data_sample,
  70. draw_pred=False,
  71. out_file=out_file)
  72. self._assert_image_and_shape(out_file, (h, w, 3))
  73. # test gt_panoptic_seg and pred_panoptic_seg
  74. det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
  75. gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  76. panoptic_seg = PixelData(sem_seg=gt_sem_seg)
  77. det_data_sample = DetDataSample()
  78. det_data_sample.gt_panoptic_seg = panoptic_seg
  79. pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  80. panoptic_seg = PixelData(sem_seg=pred_sem_seg)
  81. det_data_sample.pred_panoptic_seg = panoptic_seg
  82. det_local_visualizer.add_datasample(
  83. 'image', image, det_data_sample, out_file=out_file)
  84. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  85. # class information must be provided
  86. det_local_visualizer.dataset_meta = {}
  87. with self.assertRaises(AssertionError):
  88. det_local_visualizer.add_datasample(
  89. 'image', image, det_data_sample, out_file=out_file)
  90. def _assert_image_and_shape(self, out_file, out_shape):
  91. assert os.path.exists(out_file)
  92. drawn_img = cv2.imread(out_file)
  93. assert drawn_img.shape == out_shape
  94. os.remove(out_file)