det_data_sample.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional
  3. from mmengine.structures import BaseDataElement, InstanceData, PixelData
  4. class DetDataSample(BaseDataElement):
  5. """A data structure interface of MMDetection. They are used as interfaces
  6. between different components.
  7. The attributes in ``DetDataSample`` are divided into several parts:
  8. - ``proposals``(InstanceData): Region proposals used in two-stage
  9. detectors.
  10. - ``gt_instances``(InstanceData): Ground truth of instance annotations.
  11. - ``pred_instances``(InstanceData): Instances of model predictions.
  12. - ``ignored_instances``(InstanceData): Instances to be ignored during
  13. training/testing.
  14. - ``gt_panoptic_seg``(PixelData): Ground truth of panoptic
  15. segmentation.
  16. - ``pred_panoptic_seg``(PixelData): Prediction of panoptic
  17. segmentation.
  18. - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
  19. - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
  20. Examples:
  21. >>> import torch
  22. >>> import numpy as np
  23. >>> from mmengine.structures import InstanceData
  24. >>> from mmdet.structures import DetDataSample
  25. >>> data_sample = DetDataSample()
  26. >>> img_meta = dict(img_shape=(800, 1196),
  27. ... pad_shape=(800, 1216))
  28. >>> gt_instances = InstanceData(metainfo=img_meta)
  29. >>> gt_instances.bboxes = torch.rand((5, 4))
  30. >>> gt_instances.labels = torch.rand((5,))
  31. >>> data_sample.gt_instances = gt_instances
  32. >>> assert 'img_shape' in data_sample.gt_instances.metainfo_keys()
  33. >>> len(data_sample.gt_instances)
  34. 5
  35. >>> print(data_sample)
  36. <DetDataSample(
  37. META INFORMATION
  38. DATA FIELDS
  39. gt_instances: <InstanceData(
  40. META INFORMATION
  41. pad_shape: (800, 1216)
  42. img_shape: (800, 1196)
  43. DATA FIELDS
  44. labels: tensor([0.8533, 0.1550, 0.5433, 0.7294, 0.5098])
  45. bboxes:
  46. tensor([[9.7725e-01, 5.8417e-01, 1.7269e-01, 6.5694e-01],
  47. [1.7894e-01, 5.1780e-01, 7.0590e-01, 4.8589e-01],
  48. [7.0392e-01, 6.6770e-01, 1.7520e-01, 1.4267e-01],
  49. [2.2411e-01, 5.1962e-01, 9.6953e-01, 6.6994e-01],
  50. [4.1338e-01, 2.1165e-01, 2.7239e-04, 6.8477e-01]])
  51. ) at 0x7f21fb1b9190>
  52. ) at 0x7f21fb1b9880>
  53. >>> pred_instances = InstanceData(metainfo=img_meta)
  54. >>> pred_instances.bboxes = torch.rand((5, 4))
  55. >>> pred_instances.scores = torch.rand((5,))
  56. >>> data_sample = DetDataSample(pred_instances=pred_instances)
  57. >>> assert 'pred_instances' in data_sample
  58. >>> data_sample = DetDataSample()
  59. >>> gt_instances_data = dict(
  60. ... bboxes=torch.rand(2, 4),
  61. ... labels=torch.rand(2),
  62. ... masks=np.random.rand(2, 2, 2))
  63. >>> gt_instances = InstanceData(**gt_instances_data)
  64. >>> data_sample.gt_instances = gt_instances
  65. >>> assert 'gt_instances' in data_sample
  66. >>> assert 'masks' in data_sample.gt_instances
  67. >>> data_sample = DetDataSample()
  68. >>> gt_panoptic_seg_data = dict(panoptic_seg=torch.rand(2, 4))
  69. >>> gt_panoptic_seg = PixelData(**gt_panoptic_seg_data)
  70. >>> data_sample.gt_panoptic_seg = gt_panoptic_seg
  71. >>> print(data_sample)
  72. <DetDataSample(
  73. META INFORMATION
  74. DATA FIELDS
  75. _gt_panoptic_seg: <BaseDataElement(
  76. META INFORMATION
  77. DATA FIELDS
  78. panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
  79. [0.3200, 0.7448, 0.1052, 0.5371]])
  80. ) at 0x7f66c2bb7730>
  81. gt_panoptic_seg: <BaseDataElement(
  82. META INFORMATION
  83. DATA FIELDS
  84. panoptic_seg: tensor([[0.7586, 0.1262, 0.2892, 0.9341],
  85. [0.3200, 0.7448, 0.1052, 0.5371]])
  86. ) at 0x7f66c2bb7730>
  87. ) at 0x7f66c2bb7280>
  88. >>> data_sample = DetDataSample()
  89. >>> gt_segm_seg_data = dict(segm_seg=torch.rand(2, 2, 2))
  90. >>> gt_segm_seg = PixelData(**gt_segm_seg_data)
  91. >>> data_sample.gt_segm_seg = gt_segm_seg
  92. >>> assert 'gt_segm_seg' in data_sample
  93. >>> assert 'segm_seg' in data_sample.gt_segm_seg
  94. """
  95. @property
  96. def proposals(self) -> InstanceData:
  97. return self._proposals
  98. @proposals.setter
  99. def proposals(self, value: InstanceData):
  100. self.set_field(value, '_proposals', dtype=InstanceData)
  101. @proposals.deleter
  102. def proposals(self):
  103. del self._proposals
  104. @property
  105. def gt_instances(self) -> InstanceData:
  106. return self._gt_instances
  107. @gt_instances.setter
  108. def gt_instances(self, value: InstanceData):
  109. self.set_field(value, '_gt_instances', dtype=InstanceData)
  110. @gt_instances.deleter
  111. def gt_instances(self):
  112. del self._gt_instances
  113. @property
  114. def pred_instances(self) -> InstanceData:
  115. return self._pred_instances
  116. @pred_instances.setter
  117. def pred_instances(self, value: InstanceData):
  118. self.set_field(value, '_pred_instances', dtype=InstanceData)
  119. @pred_instances.deleter
  120. def pred_instances(self):
  121. del self._pred_instances
  122. @property
  123. def ignored_instances(self) -> InstanceData:
  124. return self._ignored_instances
  125. @ignored_instances.setter
  126. def ignored_instances(self, value: InstanceData):
  127. self.set_field(value, '_ignored_instances', dtype=InstanceData)
  128. @ignored_instances.deleter
  129. def ignored_instances(self):
  130. del self._ignored_instances
  131. @property
  132. def gt_panoptic_seg(self) -> PixelData:
  133. return self._gt_panoptic_seg
  134. @gt_panoptic_seg.setter
  135. def gt_panoptic_seg(self, value: PixelData):
  136. self.set_field(value, '_gt_panoptic_seg', dtype=PixelData)
  137. @gt_panoptic_seg.deleter
  138. def gt_panoptic_seg(self):
  139. del self._gt_panoptic_seg
  140. @property
  141. def pred_panoptic_seg(self) -> PixelData:
  142. return self._pred_panoptic_seg
  143. @pred_panoptic_seg.setter
  144. def pred_panoptic_seg(self, value: PixelData):
  145. self.set_field(value, '_pred_panoptic_seg', dtype=PixelData)
  146. @pred_panoptic_seg.deleter
  147. def pred_panoptic_seg(self):
  148. del self._pred_panoptic_seg
  149. @property
  150. def gt_sem_seg(self) -> PixelData:
  151. return self._gt_sem_seg
  152. @gt_sem_seg.setter
  153. def gt_sem_seg(self, value: PixelData):
  154. self.set_field(value, '_gt_sem_seg', dtype=PixelData)
  155. @gt_sem_seg.deleter
  156. def gt_sem_seg(self):
  157. del self._gt_sem_seg
  158. @property
  159. def pred_sem_seg(self) -> PixelData:
  160. return self._pred_sem_seg
  161. @pred_sem_seg.setter
  162. def pred_sem_seg(self, value: PixelData):
  163. self.set_field(value, '_pred_sem_seg', dtype=PixelData)
  164. @pred_sem_seg.deleter
  165. def pred_sem_seg(self):
  166. del self._pred_sem_seg
  167. SampleList = List[DetDataSample]
  168. OptSampleList = Optional[SampleList]