sampling_result.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import numpy as np
  4. import torch
  5. from torch import Tensor
  6. from mmdet.structures.bbox import BaseBoxes, cat_boxes
  7. from mmdet.utils import util_mixins
  8. from mmdet.utils.util_random import ensure_rng
  9. from ..assigners import AssignResult
  10. def random_boxes(num=1, scale=1, rng=None):
  11. """Simple version of ``kwimage.Boxes.random``
  12. Returns:
  13. Tensor: shape (n, 4) in x1, y1, x2, y2 format.
  14. References:
  15. https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
  16. Example:
  17. >>> num = 3
  18. >>> scale = 512
  19. >>> rng = 0
  20. >>> boxes = random_boxes(num, scale, rng)
  21. >>> print(boxes)
  22. tensor([[280.9925, 278.9802, 308.6148, 366.1769],
  23. [216.9113, 330.6978, 224.0446, 456.5878],
  24. [405.3632, 196.3221, 493.3953, 270.7942]])
  25. """
  26. rng = ensure_rng(rng)
  27. tlbr = rng.rand(num, 4).astype(np.float32)
  28. tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
  29. tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
  30. br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
  31. br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
  32. tlbr[:, 0] = tl_x * scale
  33. tlbr[:, 1] = tl_y * scale
  34. tlbr[:, 2] = br_x * scale
  35. tlbr[:, 3] = br_y * scale
  36. boxes = torch.from_numpy(tlbr)
  37. return boxes
  38. class SamplingResult(util_mixins.NiceRepr):
  39. """Bbox sampling result.
  40. Args:
  41. pos_inds (Tensor): Indices of positive samples.
  42. neg_inds (Tensor): Indices of negative samples.
  43. priors (Tensor): The priors can be anchors or points,
  44. or the bboxes predicted by the previous stage.
  45. gt_bboxes (Tensor): Ground truth of bboxes.
  46. assign_result (:obj:`AssignResult`): Assigning results.
  47. gt_flags (Tensor): The Ground truth flags.
  48. avg_factor_with_neg (bool): If True, ``avg_factor`` equal to
  49. the number of total priors; Otherwise, it is the number of
  50. positive priors. Defaults to True.
  51. Example:
  52. >>> # xdoctest: +IGNORE_WANT
  53. >>> from mmdet.models.task_modules.samplers.sampling_result import * # NOQA
  54. >>> self = SamplingResult.random(rng=10)
  55. >>> print(f'self = {self}')
  56. self = <SamplingResult({
  57. 'neg_inds': tensor([1, 2, 3, 5, 6, 7, 8,
  58. 9, 10, 11, 12, 13]),
  59. 'neg_priors': torch.Size([12, 4]),
  60. 'num_gts': 1,
  61. 'num_neg': 12,
  62. 'num_pos': 1,
  63. 'avg_factor': 13,
  64. 'pos_assigned_gt_inds': tensor([0]),
  65. 'pos_inds': tensor([0]),
  66. 'pos_is_gt': tensor([1], dtype=torch.uint8),
  67. 'pos_priors': torch.Size([1, 4])
  68. })>
  69. """
  70. def __init__(self,
  71. pos_inds: Tensor,
  72. neg_inds: Tensor,
  73. priors: Tensor,
  74. gt_bboxes: Tensor,
  75. assign_result: AssignResult,
  76. gt_flags: Tensor,
  77. avg_factor_with_neg: bool = True) -> None:
  78. self.pos_inds = pos_inds
  79. self.neg_inds = neg_inds
  80. self.num_pos = max(pos_inds.numel(), 1)
  81. self.num_neg = max(neg_inds.numel(), 1)
  82. self.avg_factor_with_neg = avg_factor_with_neg
  83. self.avg_factor = self.num_pos + self.num_neg \
  84. if avg_factor_with_neg else self.num_pos
  85. self.pos_priors = priors[pos_inds]
  86. self.neg_priors = priors[neg_inds]
  87. self.pos_is_gt = gt_flags[pos_inds]
  88. self.num_gts = gt_bboxes.shape[0]
  89. self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
  90. self.pos_gt_labels = assign_result.labels[pos_inds]
  91. box_dim = gt_bboxes.box_dim if isinstance(gt_bboxes, BaseBoxes) else 4
  92. if gt_bboxes.numel() == 0:
  93. # hack for index error case
  94. assert self.pos_assigned_gt_inds.numel() == 0
  95. self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim)
  96. else:
  97. if len(gt_bboxes.shape) < 2:
  98. gt_bboxes = gt_bboxes.view(-1, box_dim)
  99. self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()]
  100. @property
  101. def priors(self):
  102. """torch.Tensor: concatenated positive and negative priors"""
  103. return cat_boxes([self.pos_priors, self.neg_priors])
  104. @property
  105. def bboxes(self):
  106. """torch.Tensor: concatenated positive and negative boxes"""
  107. warnings.warn('DeprecationWarning: bboxes is deprecated, '
  108. 'please use "priors" instead')
  109. return self.priors
  110. @property
  111. def pos_bboxes(self):
  112. warnings.warn('DeprecationWarning: pos_bboxes is deprecated, '
  113. 'please use "pos_priors" instead')
  114. return self.pos_priors
  115. @property
  116. def neg_bboxes(self):
  117. warnings.warn('DeprecationWarning: neg_bboxes is deprecated, '
  118. 'please use "neg_priors" instead')
  119. return self.neg_priors
  120. def to(self, device):
  121. """Change the device of the data inplace.
  122. Example:
  123. >>> self = SamplingResult.random()
  124. >>> print(f'self = {self.to(None)}')
  125. >>> # xdoctest: +REQUIRES(--gpu)
  126. >>> print(f'self = {self.to(0)}')
  127. """
  128. _dict = self.__dict__
  129. for key, value in _dict.items():
  130. if isinstance(value, (torch.Tensor, BaseBoxes)):
  131. _dict[key] = value.to(device)
  132. return self
  133. def __nice__(self):
  134. data = self.info.copy()
  135. data['pos_priors'] = data.pop('pos_priors').shape
  136. data['neg_priors'] = data.pop('neg_priors').shape
  137. parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
  138. body = ' ' + ',\n '.join(parts)
  139. return '{\n' + body + '\n}'
  140. @property
  141. def info(self):
  142. """Returns a dictionary of info about the object."""
  143. return {
  144. 'pos_inds': self.pos_inds,
  145. 'neg_inds': self.neg_inds,
  146. 'pos_priors': self.pos_priors,
  147. 'neg_priors': self.neg_priors,
  148. 'pos_is_gt': self.pos_is_gt,
  149. 'num_gts': self.num_gts,
  150. 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
  151. 'num_pos': self.num_pos,
  152. 'num_neg': self.num_neg,
  153. 'avg_factor': self.avg_factor
  154. }
  155. @classmethod
  156. def random(cls, rng=None, **kwargs):
  157. """
  158. Args:
  159. rng (None | int | numpy.random.RandomState): seed or state.
  160. kwargs (keyword arguments):
  161. - num_preds: Number of predicted boxes.
  162. - num_gts: Number of true boxes.
  163. - p_ignore (float): Probability of a predicted box assigned to
  164. an ignored truth.
  165. - p_assigned (float): probability of a predicted box not being
  166. assigned.
  167. Returns:
  168. :obj:`SamplingResult`: Randomly generated sampling result.
  169. Example:
  170. >>> from mmdet.models.task_modules.samplers.sampling_result import * # NOQA
  171. >>> self = SamplingResult.random()
  172. >>> print(self.__dict__)
  173. """
  174. from mmengine.structures import InstanceData
  175. from mmdet.models.task_modules.assigners import AssignResult
  176. from mmdet.models.task_modules.samplers import RandomSampler
  177. rng = ensure_rng(rng)
  178. # make probabilistic?
  179. num = 32
  180. pos_fraction = 0.5
  181. neg_pos_ub = -1
  182. assign_result = AssignResult.random(rng=rng, **kwargs)
  183. # Note we could just compute an assignment
  184. priors = random_boxes(assign_result.num_preds, rng=rng)
  185. gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
  186. gt_labels = torch.randint(
  187. 0, 5, (assign_result.num_gts, ), dtype=torch.long)
  188. pred_instances = InstanceData()
  189. pred_instances.priors = priors
  190. gt_instances = InstanceData()
  191. gt_instances.bboxes = gt_bboxes
  192. gt_instances.labels = gt_labels
  193. add_gt_as_proposals = True
  194. sampler = RandomSampler(
  195. num,
  196. pos_fraction,
  197. neg_pos_ub=neg_pos_ub,
  198. add_gt_as_proposals=add_gt_as_proposals,
  199. rng=rng)
  200. self = sampler.sample(
  201. assign_result=assign_result,
  202. pred_instances=pred_instances,
  203. gt_instances=gt_instances)
  204. return self