multi_instance_random_sampler.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from numpy import ndarray
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from ..assigners import AssignResult
  9. from .multi_instance_sampling_result import MultiInstanceSamplingResult
  10. from .random_sampler import RandomSampler
  11. @TASK_UTILS.register_module()
  12. class MultiInsRandomSampler(RandomSampler):
  13. """Random sampler for multi instance.
  14. Note:
  15. Multi-instance means to predict multiple detection boxes with
  16. one proposal box. `AssignResult` may assign multiple gt boxes
  17. to each proposal box, in this case `RandomSampler` should be
  18. replaced by `MultiInsRandomSampler`
  19. """
  20. def _sample_pos(self, assign_result: AssignResult, num_expected: int,
  21. **kwargs) -> Union[Tensor, ndarray]:
  22. """Randomly sample some positive samples.
  23. Args:
  24. assign_result (:obj:`AssignResult`): Bbox assigning results.
  25. num_expected (int): The number of expected positive samples
  26. Returns:
  27. Tensor or ndarray: sampled indices.
  28. """
  29. pos_inds = torch.nonzero(
  30. assign_result.labels[:, 0] > 0, as_tuple=False)
  31. if pos_inds.numel() != 0:
  32. pos_inds = pos_inds.squeeze(1)
  33. if pos_inds.numel() <= num_expected:
  34. return pos_inds
  35. else:
  36. return self.random_choice(pos_inds, num_expected)
  37. def _sample_neg(self, assign_result: AssignResult, num_expected: int,
  38. **kwargs) -> Union[Tensor, ndarray]:
  39. """Randomly sample some negative samples.
  40. Args:
  41. assign_result (:obj:`AssignResult`): Bbox assigning results.
  42. num_expected (int): The number of expected positive samples
  43. Returns:
  44. Tensor or ndarray: sampled indices.
  45. """
  46. neg_inds = torch.nonzero(
  47. assign_result.labels[:, 0] == 0, as_tuple=False)
  48. if neg_inds.numel() != 0:
  49. neg_inds = neg_inds.squeeze(1)
  50. if len(neg_inds) <= num_expected:
  51. return neg_inds
  52. else:
  53. return self.random_choice(neg_inds, num_expected)
  54. def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
  55. gt_instances: InstanceData,
  56. **kwargs) -> MultiInstanceSamplingResult:
  57. """Sample positive and negative bboxes.
  58. Args:
  59. assign_result (:obj:`AssignResult`): Assigning results from
  60. MultiInstanceAssigner.
  61. pred_instances (:obj:`InstanceData`): Instances of model
  62. predictions. It includes ``priors``, and the priors can
  63. be anchors or points, or the bboxes predicted by the
  64. previous stage, has shape (n, 4). The bboxes predicted by
  65. the current model or stage will be named ``bboxes``,
  66. ``labels``, and ``scores``, the same as the ``InstanceData``
  67. in other places.
  68. gt_instances (:obj:`InstanceData`): Ground truth of instance
  69. annotations. It usually includes ``bboxes``, with shape (k, 4),
  70. and ``labels``, with shape (k, ).
  71. Returns:
  72. :obj:`MultiInstanceSamplingResult`: Sampling result.
  73. """
  74. assert 'batch_gt_instances_ignore' in kwargs, \
  75. 'batch_gt_instances_ignore is necessary for MultiInsRandomSampler'
  76. gt_bboxes = gt_instances.bboxes
  77. ignore_bboxes = kwargs['batch_gt_instances_ignore'].bboxes
  78. gt_and_ignore_bboxes = torch.cat([gt_bboxes, ignore_bboxes], dim=0)
  79. priors = pred_instances.priors
  80. if len(priors.shape) < 2:
  81. priors = priors[None, :]
  82. priors = priors[:, :4]
  83. gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
  84. priors = torch.cat([priors, gt_and_ignore_bboxes], dim=0)
  85. gt_ones = priors.new_ones(
  86. gt_and_ignore_bboxes.shape[0], dtype=torch.uint8)
  87. gt_flags = torch.cat([gt_flags, gt_ones])
  88. num_expected_pos = int(self.num * self.pos_fraction)
  89. pos_inds = self.pos_sampler._sample_pos(assign_result,
  90. num_expected_pos)
  91. # We found that sampled indices have duplicated items occasionally.
  92. # (may be a bug of PyTorch)
  93. pos_inds = pos_inds.unique()
  94. num_sampled_pos = pos_inds.numel()
  95. num_expected_neg = self.num - num_sampled_pos
  96. if self.neg_pos_ub >= 0:
  97. _pos = max(1, num_sampled_pos)
  98. neg_upper_bound = int(self.neg_pos_ub * _pos)
  99. if num_expected_neg > neg_upper_bound:
  100. num_expected_neg = neg_upper_bound
  101. neg_inds = self.neg_sampler._sample_neg(assign_result,
  102. num_expected_neg)
  103. neg_inds = neg_inds.unique()
  104. sampling_result = MultiInstanceSamplingResult(
  105. pos_inds=pos_inds,
  106. neg_inds=neg_inds,
  107. priors=priors,
  108. gt_and_ignore_bboxes=gt_and_ignore_bboxes,
  109. assign_result=assign_result,
  110. gt_flags=gt_flags)
  111. return sampling_result