# test_local_visualizer.py (4.1 KB)

import os
import tempfile
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmdet.evaluation import INSTANCE_OFFSET
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
  10. def _rand_bboxes(num_boxes, h, w):
  11. cx, cy, bw, bh = torch.rand(num_boxes, 4).T
  12. tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
  13. tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
  14. br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
  15. br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
  16. bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
  17. return bboxes
  18. def _create_panoptic_data(num_boxes, h, w):
  19. sem_seg = np.zeros((h, w), dtype=np.int64) + 2
  20. bboxes = _rand_bboxes(num_boxes, h, w).int()
  21. labels = torch.randint(2, (num_boxes, ))
  22. for i in range(num_boxes):
  23. x, y, w, h = bboxes[i]
  24. sem_seg[y:y + h, x:x + w] = (i + 1) * INSTANCE_OFFSET + labels[i]
  25. return sem_seg[None]
  26. class TestDetLocalVisualizer(TestCase):
  27. def test_add_datasample(self):
  28. h = 12
  29. w = 10
  30. num_class = 3
  31. num_bboxes = 5
  32. out_file = 'out_file.jpg'
  33. image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
  34. # test gt_instances
  35. gt_instances = InstanceData()
  36. gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  37. gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  38. det_data_sample = DetDataSample()
  39. det_data_sample.gt_instances = gt_instances
  40. det_local_visualizer = DetLocalVisualizer()
  41. det_local_visualizer.add_datasample(
  42. 'image', image, det_data_sample, draw_pred=False)
  43. # test out_file
  44. det_local_visualizer.add_datasample(
  45. 'image',
  46. image,
  47. det_data_sample,
  48. draw_pred=False,
  49. out_file=out_file)
  50. assert os.path.exists(out_file)
  51. drawn_img = cv2.imread(out_file)
  52. assert drawn_img.shape == (h, w, 3)
  53. os.remove(out_file)
  54. # test gt_instances and pred_instances
  55. pred_instances = InstanceData()
  56. pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
  57. pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
  58. pred_instances.scores = torch.rand((num_bboxes, ))
  59. det_data_sample.pred_instances = pred_instances
  60. det_local_visualizer.add_datasample(
  61. 'image', image, det_data_sample, out_file=out_file)
  62. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  63. det_local_visualizer.add_datasample(
  64. 'image', image, det_data_sample, draw_gt=False, out_file=out_file)
  65. self._assert_image_and_shape(out_file, (h, w, 3))
  66. det_local_visualizer.add_datasample(
  67. 'image',
  68. image,
  69. det_data_sample,
  70. draw_pred=False,
  71. out_file=out_file)
  72. self._assert_image_and_shape(out_file, (h, w, 3))
  73. # test gt_panoptic_seg and pred_panoptic_seg
  74. det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
  75. gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  76. panoptic_seg = PixelData(sem_seg=gt_sem_seg)
  77. det_data_sample = DetDataSample()
  78. det_data_sample.gt_panoptic_seg = panoptic_seg
  79. pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
  80. panoptic_seg = PixelData(sem_seg=pred_sem_seg)
  81. det_data_sample.pred_panoptic_seg = panoptic_seg
  82. det_local_visualizer.add_datasample(
  83. 'image', image, det_data_sample, out_file=out_file)
  84. self._assert_image_and_shape(out_file, (h, w * 2, 3))
  85. # class information must be provided
  86. det_local_visualizer.dataset_meta = {}
  87. with self.assertRaises(AssertionError):
  88. det_local_visualizer.add_datasample(
  89. 'image', image, det_data_sample, out_file=out_file)
  90. def _assert_image_and_shape(self, out_file, out_shape):
  91. assert os.path.exists(out_file)
  92. drawn_img = cv2.imread(out_file)
  93. assert drawn_img.shape == out_shape
  94. os.remove(out_file)