123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import os
- from unittest import TestCase
- import cv2
- import numpy as np
- import torch
- from mmengine.structures import InstanceData, PixelData
- from mmdet.evaluation import INSTANCE_OFFSET
- from mmdet.structures import DetDataSample
- from mmdet.visualization import DetLocalVisualizer
- def _rand_bboxes(num_boxes, h, w):
- cx, cy, bw, bh = torch.rand(num_boxes, 4).T
- tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w)
- tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h)
- br_x = ((cx * w) + (w * bw / 2)).clamp(0, w)
- br_y = ((cy * h) + (h * bh / 2)).clamp(0, h)
- bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=0).T
- return bboxes
- def _create_panoptic_data(num_boxes, h, w):
- sem_seg = np.zeros((h, w), dtype=np.int64) + 2
- bboxes = _rand_bboxes(num_boxes, h, w).int()
- labels = torch.randint(2, (num_boxes, ))
- for i in range(num_boxes):
- x, y, w, h = bboxes[i]
- sem_seg[y:y + h, x:x + w] = (i + 1) * INSTANCE_OFFSET + labels[i]
- return sem_seg[None]
- class TestDetLocalVisualizer(TestCase):
- def test_add_datasample(self):
- h = 12
- w = 10
- num_class = 3
- num_bboxes = 5
- out_file = 'out_file.jpg'
- image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
- # test gt_instances
- gt_instances = InstanceData()
- gt_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
- gt_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
- det_data_sample = DetDataSample()
- det_data_sample.gt_instances = gt_instances
- det_local_visualizer = DetLocalVisualizer()
- det_local_visualizer.add_datasample(
- 'image', image, det_data_sample, draw_pred=False)
- # test out_file
- det_local_visualizer.add_datasample(
- 'image',
- image,
- det_data_sample,
- draw_pred=False,
- out_file=out_file)
- assert os.path.exists(out_file)
- drawn_img = cv2.imread(out_file)
- assert drawn_img.shape == (h, w, 3)
- os.remove(out_file)
- # test gt_instances and pred_instances
- pred_instances = InstanceData()
- pred_instances.bboxes = _rand_bboxes(num_bboxes, h, w)
- pred_instances.labels = torch.randint(0, num_class, (num_bboxes, ))
- pred_instances.scores = torch.rand((num_bboxes, ))
- det_data_sample.pred_instances = pred_instances
- det_local_visualizer.add_datasample(
- 'image', image, det_data_sample, out_file=out_file)
- self._assert_image_and_shape(out_file, (h, w * 2, 3))
- det_local_visualizer.add_datasample(
- 'image', image, det_data_sample, draw_gt=False, out_file=out_file)
- self._assert_image_and_shape(out_file, (h, w, 3))
- det_local_visualizer.add_datasample(
- 'image',
- image,
- det_data_sample,
- draw_pred=False,
- out_file=out_file)
- self._assert_image_and_shape(out_file, (h, w, 3))
- # test gt_panoptic_seg and pred_panoptic_seg
- det_local_visualizer.dataset_meta = dict(classes=('1', '2'))
- gt_sem_seg = _create_panoptic_data(num_bboxes, h, w)
- panoptic_seg = PixelData(sem_seg=gt_sem_seg)
- det_data_sample = DetDataSample()
- det_data_sample.gt_panoptic_seg = panoptic_seg
- pred_sem_seg = _create_panoptic_data(num_bboxes, h, w)
- panoptic_seg = PixelData(sem_seg=pred_sem_seg)
- det_data_sample.pred_panoptic_seg = panoptic_seg
- det_local_visualizer.add_datasample(
- 'image', image, det_data_sample, out_file=out_file)
- self._assert_image_and_shape(out_file, (h, w * 2, 3))
- # class information must be provided
- det_local_visualizer.dataset_meta = {}
- with self.assertRaises(AssertionError):
- det_local_visualizer.add_datasample(
- 'image', image, det_data_sample, out_file=out_file)
- def _assert_image_and_shape(self, out_file, out_shape):
- assert os.path.exists(out_file)
- drawn_img = cv2.imread(out_file)
- assert drawn_img.shape == out_shape
- os.remove(out_file)
|