base_sampler.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.structures.bbox import BaseBoxes, cat_boxes
  6. from ..assigners import AssignResult
  7. from .sampling_result import SamplingResult
  8. class BaseSampler(metaclass=ABCMeta):
  9. """Base class of samplers.
  10. Args:
  11. num (int): Number of samples
  12. pos_fraction (float): Fraction of positive samples
  13. neg_pos_up (int): Upper bound number of negative and
  14. positive samples. Defaults to -1.
  15. add_gt_as_proposals (bool): Whether to add ground truth
  16. boxes as proposals. Defaults to True.
  17. """
  18. def __init__(self,
  19. num: int,
  20. pos_fraction: float,
  21. neg_pos_ub: int = -1,
  22. add_gt_as_proposals: bool = True,
  23. **kwargs) -> None:
  24. self.num = num
  25. self.pos_fraction = pos_fraction
  26. self.neg_pos_ub = neg_pos_ub
  27. self.add_gt_as_proposals = add_gt_as_proposals
  28. self.pos_sampler = self
  29. self.neg_sampler = self
  30. @abstractmethod
  31. def _sample_pos(self, assign_result: AssignResult, num_expected: int,
  32. **kwargs):
  33. """Sample positive samples."""
  34. pass
  35. @abstractmethod
  36. def _sample_neg(self, assign_result: AssignResult, num_expected: int,
  37. **kwargs):
  38. """Sample negative samples."""
  39. pass
  40. def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
  41. gt_instances: InstanceData, **kwargs) -> SamplingResult:
  42. """Sample positive and negative bboxes.
  43. This is a simple implementation of bbox sampling given candidates,
  44. assigning results and ground truth bboxes.
  45. Args:
  46. assign_result (:obj:`AssignResult`): Assigning results.
  47. pred_instances (:obj:`InstanceData`): Instances of model
  48. predictions. It includes ``priors``, and the priors can
  49. be anchors or points, or the bboxes predicted by the
  50. previous stage, has shape (n, 4). The bboxes predicted by
  51. the current model or stage will be named ``bboxes``,
  52. ``labels``, and ``scores``, the same as the ``InstanceData``
  53. in other places.
  54. gt_instances (:obj:`InstanceData`): Ground truth of instance
  55. annotations. It usually includes ``bboxes``, with shape (k, 4),
  56. and ``labels``, with shape (k, ).
  57. Returns:
  58. :obj:`SamplingResult`: Sampling result.
  59. Example:
  60. >>> from mmengine.structures import InstanceData
  61. >>> from mmdet.models.task_modules.samplers import RandomSampler,
  62. >>> from mmdet.models.task_modules.assigners import AssignResult
  63. >>> from mmdet.models.task_modules.samplers.
  64. ... sampling_result import ensure_rng, random_boxes
  65. >>> rng = ensure_rng(None)
  66. >>> assign_result = AssignResult.random(rng=rng)
  67. >>> pred_instances = InstanceData()
  68. >>> pred_instances.priors = random_boxes(assign_result.num_preds,
  69. ... rng=rng)
  70. >>> gt_instances = InstanceData()
  71. >>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
  72. ... rng=rng)
  73. >>> gt_instances.labels = torch.randint(
  74. ... 0, 5, (assign_result.num_gts,), dtype=torch.long)
  75. >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
  76. >>> add_gt_as_proposals=False)
  77. >>> self = self.sample(assign_result, pred_instances, gt_instances)
  78. """
  79. gt_bboxes = gt_instances.bboxes
  80. priors = pred_instances.priors
  81. gt_labels = gt_instances.labels
  82. if len(priors.shape) < 2:
  83. priors = priors[None, :]
  84. gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
  85. if self.add_gt_as_proposals and len(gt_bboxes) > 0:
  86. # When `gt_bboxes` and `priors` are all box type, convert
  87. # `gt_bboxes` type to `priors` type.
  88. if (isinstance(gt_bboxes, BaseBoxes)
  89. and isinstance(priors, BaseBoxes)):
  90. gt_bboxes_ = gt_bboxes.convert_to(type(priors))
  91. else:
  92. gt_bboxes_ = gt_bboxes
  93. priors = cat_boxes([gt_bboxes_, priors], dim=0)
  94. assign_result.add_gt_(gt_labels)
  95. gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
  96. gt_flags = torch.cat([gt_ones, gt_flags])
  97. num_expected_pos = int(self.num * self.pos_fraction)
  98. pos_inds = self.pos_sampler._sample_pos(
  99. assign_result, num_expected_pos, bboxes=priors, **kwargs)
  100. # We found that sampled indices have duplicated items occasionally.
  101. # (may be a bug of PyTorch)
  102. pos_inds = pos_inds.unique()
  103. num_sampled_pos = pos_inds.numel()
  104. num_expected_neg = self.num - num_sampled_pos
  105. if self.neg_pos_ub >= 0:
  106. _pos = max(1, num_sampled_pos)
  107. neg_upper_bound = int(self.neg_pos_ub * _pos)
  108. if num_expected_neg > neg_upper_bound:
  109. num_expected_neg = neg_upper_bound
  110. neg_inds = self.neg_sampler._sample_neg(
  111. assign_result, num_expected_neg, bboxes=priors, **kwargs)
  112. neg_inds = neg_inds.unique()
  113. sampling_result = SamplingResult(
  114. pos_inds=pos_inds,
  115. neg_inds=neg_inds,
  116. priors=priors,
  117. gt_bboxes=gt_bboxes,
  118. assign_result=assign_result,
  119. gt_flags=gt_flags)
  120. return sampling_result