test_det_data_sample.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from unittest import TestCase
  2. import numpy as np
  3. import pytest
  4. import torch
  5. from mmengine.structures import InstanceData, PixelData
  6. from mmdet.structures import DetDataSample
  7. def _equal(a, b):
  8. if isinstance(a, (torch.Tensor, np.ndarray)):
  9. return (a == b).all()
  10. else:
  11. return a == b
  12. class TestDetDataSample(TestCase):
  13. def test_init(self):
  14. meta_info = dict(
  15. img_size=[256, 256],
  16. scale_factor=np.array([1.5, 1.5]),
  17. img_shape=torch.rand(4))
  18. det_data_sample = DetDataSample(metainfo=meta_info)
  19. assert 'img_size' in det_data_sample
  20. assert det_data_sample.img_size == [256, 256]
  21. assert det_data_sample.get('img_size') == [256, 256]
  22. def test_setter(self):
  23. det_data_sample = DetDataSample()
  24. # test gt_instances
  25. gt_instances_data = dict(
  26. bboxes=torch.rand(4, 4),
  27. labels=torch.rand(4),
  28. masks=np.random.rand(4, 2, 2))
  29. gt_instances = InstanceData(**gt_instances_data)
  30. det_data_sample.gt_instances = gt_instances
  31. assert 'gt_instances' in det_data_sample
  32. assert _equal(det_data_sample.gt_instances.bboxes,
  33. gt_instances_data['bboxes'])
  34. assert _equal(det_data_sample.gt_instances.labels,
  35. gt_instances_data['labels'])
  36. assert _equal(det_data_sample.gt_instances.masks,
  37. gt_instances_data['masks'])
  38. # test pred_instances
  39. pred_instances_data = dict(
  40. bboxes=torch.rand(2, 4),
  41. labels=torch.rand(2),
  42. masks=np.random.rand(2, 2, 2))
  43. pred_instances = InstanceData(**pred_instances_data)
  44. det_data_sample.pred_instances = pred_instances
  45. assert 'pred_instances' in det_data_sample
  46. assert _equal(det_data_sample.pred_instances.bboxes,
  47. pred_instances_data['bboxes'])
  48. assert _equal(det_data_sample.pred_instances.labels,
  49. pred_instances_data['labels'])
  50. assert _equal(det_data_sample.pred_instances.masks,
  51. pred_instances_data['masks'])
  52. # test proposals
  53. proposals_data = dict(bboxes=torch.rand(4, 4), labels=torch.rand(4))
  54. proposals = InstanceData(**proposals_data)
  55. det_data_sample.proposals = proposals
  56. assert 'proposals' in det_data_sample
  57. assert _equal(det_data_sample.proposals.bboxes,
  58. proposals_data['bboxes'])
  59. assert _equal(det_data_sample.proposals.labels,
  60. proposals_data['labels'])
  61. # test ignored_instances
  62. ignored_instances_data = dict(
  63. bboxes=torch.rand(4, 4), labels=torch.rand(4))
  64. ignored_instances = InstanceData(**ignored_instances_data)
  65. det_data_sample.ignored_instances = ignored_instances
  66. assert 'ignored_instances' in det_data_sample
  67. assert _equal(det_data_sample.ignored_instances.bboxes,
  68. ignored_instances_data['bboxes'])
  69. assert _equal(det_data_sample.ignored_instances.labels,
  70. ignored_instances_data['labels'])
  71. # test gt_panoptic_seg
  72. gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
  73. gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
  74. det_data_sample.gt_panoptic_seg = gt_panoptic_seg
  75. assert 'gt_panoptic_seg' in det_data_sample
  76. assert _equal(det_data_sample.gt_panoptic_seg.panoptic_seg,
  77. gt_panoptic_seg_data['panoptic_seg'])
  78. # test pred_panoptic_seg
  79. pred_panoptic_seg_data = dict(panoptic_seg=torch.rand(5, 4))
  80. pred_panoptic_seg = PixelData(**pred_panoptic_seg_data)
  81. det_data_sample.pred_panoptic_seg = pred_panoptic_seg
  82. assert 'pred_panoptic_seg' in det_data_sample
  83. assert _equal(det_data_sample.pred_panoptic_seg.panoptic_seg,
  84. pred_panoptic_seg_data['panoptic_seg'])
  85. # test gt_sem_seg
  86. gt_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  87. gt_segm_seg = PixelData(**gt_segm_seg_data)
  88. det_data_sample.gt_segm_seg = gt_segm_seg
  89. assert 'gt_segm_seg' in det_data_sample
  90. assert _equal(det_data_sample.gt_segm_seg.segm_seg,
  91. gt_segm_seg_data['segm_seg'])
  92. # test pred_segm_seg
  93. pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  94. pred_segm_seg = PixelData(**pred_segm_seg_data)
  95. det_data_sample.pred_segm_seg = pred_segm_seg
  96. assert 'pred_segm_seg' in det_data_sample
  97. assert _equal(det_data_sample.pred_segm_seg.segm_seg,
  98. pred_segm_seg_data['segm_seg'])
  99. # test type error
  100. with pytest.raises(AssertionError):
  101. det_data_sample.pred_instances = torch.rand(2, 4)
  102. with pytest.raises(AssertionError):
  103. det_data_sample.pred_panoptic_seg = torch.rand(2, 4)
  104. with pytest.raises(AssertionError):
  105. det_data_sample.pred_sem_seg = torch.rand(2, 4)
  106. def test_deleter(self):
  107. gt_instances_data = dict(
  108. bboxes=torch.rand(4, 4),
  109. labels=torch.rand(4),
  110. masks=np.random.rand(4, 2, 2))
  111. det_data_sample = DetDataSample()
  112. gt_instances = InstanceData(data=gt_instances_data)
  113. det_data_sample.gt_instances = gt_instances
  114. assert 'gt_instances' in det_data_sample
  115. del det_data_sample.gt_instances
  116. assert 'gt_instances' not in det_data_sample
  117. pred_panoptic_seg_data = torch.rand(5, 4)
  118. pred_panoptic_seg = PixelData(data=pred_panoptic_seg_data)
  119. det_data_sample.pred_panoptic_seg = pred_panoptic_seg
  120. assert 'pred_panoptic_seg' in det_data_sample
  121. del det_data_sample.pred_panoptic_seg
  122. assert 'pred_panoptic_seg' not in det_data_sample
  123. pred_segm_seg_data = dict(segm_seg=torch.rand(5, 4, 2))
  124. pred_segm_seg = PixelData(**pred_segm_seg_data)
  125. det_data_sample.pred_segm_seg = pred_segm_seg
  126. assert 'pred_segm_seg' in det_data_sample
  127. del det_data_sample.pred_segm_seg
  128. assert 'pred_segm_seg' not in det_data_sample