123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- from unittest import TestCase
- import numpy as np
- import pytest
- import torch
- from mmengine.structures import InstanceData, PixelData
- from mmdet.structures import DetDataSample
- def _equal(a, b):
- if isinstance(a, (torch.Tensor, np.ndarray)):
- return (a == b).all()
- else:
- return a == b
class TestDetDataSample(TestCase):
    """Tests for construction, field assignment and deletion on
    ``DetDataSample``."""

    @staticmethod
    def _assert_round_trip(sample, field, element_cls, data):
        """Store ``element_cls(**data)`` as ``sample.<field>`` and verify it.

        Args:
            sample: The data sample that receives the field.
            field (str): Attribute name assigned on ``sample``.
            element_cls: Class instantiated as ``element_cls(**data)``.
            data (dict): Keyword arguments for ``element_cls``. Each entry
                is compared with the same-named attribute of the stored
                element.
        """
        setattr(sample, field, element_cls(**data))
        assert field in sample
        stored = getattr(sample, field)
        for key, expected in data.items():
            assert _equal(getattr(stored, key), expected)

    def test_init(self):
        """Meta information given at construction is readable afterwards."""
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        det_data_sample = DetDataSample(metainfo=meta_info)
        assert 'img_size' in det_data_sample
        assert det_data_sample.img_size == [256, 256]
        assert det_data_sample.get('img_size') == [256, 256]
        # The array-valued entries are checked as well, not only the list.
        assert _equal(det_data_sample.scale_factor,
                      meta_info['scale_factor'])
        assert _equal(det_data_sample.img_shape, meta_info['img_shape'])

    def test_setter(self):
        """Assigned fields read back unchanged; wrong types are rejected."""
        det_data_sample = DetDataSample()

        # Fields holding per-instance data.
        self._assert_round_trip(
            det_data_sample, 'gt_instances', InstanceData,
            dict(
                bboxes=torch.rand(4, 4),
                labels=torch.rand(4),
                masks=np.random.rand(4, 2, 2)))
        self._assert_round_trip(
            det_data_sample, 'pred_instances', InstanceData,
            dict(
                bboxes=torch.rand(2, 4),
                labels=torch.rand(2),
                masks=np.random.rand(2, 2, 2)))
        self._assert_round_trip(
            det_data_sample, 'proposals', InstanceData,
            dict(bboxes=torch.rand(4, 4), labels=torch.rand(4)))
        self._assert_round_trip(
            det_data_sample, 'ignored_instances', InstanceData,
            dict(bboxes=torch.rand(4, 4), labels=torch.rand(4)))

        # Fields holding per-pixel data.
        self._assert_round_trip(
            det_data_sample, 'gt_panoptic_seg', PixelData,
            dict(panoptic_seg=torch.rand(5, 4)))
        self._assert_round_trip(
            det_data_sample, 'pred_panoptic_seg', PixelData,
            dict(panoptic_seg=torch.rand(5, 4)))
        # NOTE(review): ``*_sem_seg`` is used here because ``pred_sem_seg``
        # is the name whose type is checked below; ``*_segm_seg`` would only
        # create an unrelated ad-hoc attribute. Confirm against
        # ``DetDataSample``.
        self._assert_round_trip(
            det_data_sample, 'gt_sem_seg', PixelData,
            dict(sem_seg=torch.rand(5, 4, 2)))
        self._assert_round_trip(
            det_data_sample, 'pred_sem_seg', PixelData,
            dict(sem_seg=torch.rand(5, 4, 2)))

        # Assigning a bare tensor to a typed field must fail.
        for field in ('pred_instances', 'pred_panoptic_seg', 'pred_sem_seg'):
            with pytest.raises(AssertionError):
                setattr(det_data_sample, field, torch.rand(2, 4))

    def test_deleter(self):
        """A field is present after assignment and absent after ``del``."""
        det_data_sample = DetDataSample()

        # Elements are built from keyword fields, the same way as in
        # ``test_setter``; ``data=...`` would store one field literally
        # named ``data``.
        elements = (
            ('gt_instances',
             InstanceData(
                 bboxes=torch.rand(4, 4),
                 labels=torch.rand(4),
                 masks=np.random.rand(4, 2, 2))),
            ('pred_panoptic_seg',
             PixelData(panoptic_seg=torch.rand(5, 4))),
            ('pred_sem_seg', PixelData(sem_seg=torch.rand(5, 4, 2))),
        )
        for field, element in elements:
            setattr(det_data_sample, field, element)
            assert field in det_data_sample
            delattr(det_data_sample, field)
            assert field not in det_data_sample
|