iou_balanced_neg_sampler.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmdet.registry import TASK_UTILS
  5. from .random_sampler import RandomSampler
  6. @TASK_UTILS.register_module()
  7. class IoUBalancedNegSampler(RandomSampler):
  8. """IoU Balanced Sampling.
  9. arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
  10. Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
  11. are sampled from proposals whose IoU are lower than `floor_thr` randomly.
  12. The others are sampled from proposals whose IoU are higher than
  13. `floor_thr`. These proposals are sampled from some bins evenly, which are
  14. split by `num_bins` via IoU evenly.
  15. Args:
  16. num (int): number of proposals.
  17. pos_fraction (float): fraction of positive proposals.
  18. floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
  19. set to -1 if all using IoU balanced sampling.
  20. floor_fraction (float): sampling fraction of proposals under floor_thr.
  21. num_bins (int): number of bins in IoU balanced sampling.
  22. """
  23. def __init__(self,
  24. num,
  25. pos_fraction,
  26. floor_thr=-1,
  27. floor_fraction=0,
  28. num_bins=3,
  29. **kwargs):
  30. super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
  31. **kwargs)
  32. assert floor_thr >= 0 or floor_thr == -1
  33. assert 0 <= floor_fraction <= 1
  34. assert num_bins >= 1
  35. self.floor_thr = floor_thr
  36. self.floor_fraction = floor_fraction
  37. self.num_bins = num_bins
  38. def sample_via_interval(self, max_overlaps, full_set, num_expected):
  39. """Sample according to the iou interval.
  40. Args:
  41. max_overlaps (torch.Tensor): IoU between bounding boxes and ground
  42. truth boxes.
  43. full_set (set(int)): A full set of indices of boxes。
  44. num_expected (int): Number of expected samples。
  45. Returns:
  46. np.ndarray: Indices of samples
  47. """
  48. max_iou = max_overlaps.max()
  49. iou_interval = (max_iou - self.floor_thr) / self.num_bins
  50. per_num_expected = int(num_expected / self.num_bins)
  51. sampled_inds = []
  52. for i in range(self.num_bins):
  53. start_iou = self.floor_thr + i * iou_interval
  54. end_iou = self.floor_thr + (i + 1) * iou_interval
  55. tmp_set = set(
  56. np.where(
  57. np.logical_and(max_overlaps >= start_iou,
  58. max_overlaps < end_iou))[0])
  59. tmp_inds = list(tmp_set & full_set)
  60. if len(tmp_inds) > per_num_expected:
  61. tmp_sampled_set = self.random_choice(tmp_inds,
  62. per_num_expected)
  63. else:
  64. tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
  65. sampled_inds.append(tmp_sampled_set)
  66. sampled_inds = np.concatenate(sampled_inds)
  67. if len(sampled_inds) < num_expected:
  68. num_extra = num_expected - len(sampled_inds)
  69. extra_inds = np.array(list(full_set - set(sampled_inds)))
  70. if len(extra_inds) > num_extra:
  71. extra_inds = self.random_choice(extra_inds, num_extra)
  72. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  73. return sampled_inds
  74. def _sample_neg(self, assign_result, num_expected, **kwargs):
  75. """Sample negative boxes.
  76. Args:
  77. assign_result (:obj:`AssignResult`): The assigned results of boxes.
  78. num_expected (int): The number of expected negative samples
  79. Returns:
  80. Tensor or ndarray: sampled indices.
  81. """
  82. neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
  83. if neg_inds.numel() != 0:
  84. neg_inds = neg_inds.squeeze(1)
  85. if len(neg_inds) <= num_expected:
  86. return neg_inds
  87. else:
  88. max_overlaps = assign_result.max_overlaps.cpu().numpy()
  89. # balance sampling for negative samples
  90. neg_set = set(neg_inds.cpu().numpy())
  91. if self.floor_thr > 0:
  92. floor_set = set(
  93. np.where(
  94. np.logical_and(max_overlaps >= 0,
  95. max_overlaps < self.floor_thr))[0])
  96. iou_sampling_set = set(
  97. np.where(max_overlaps >= self.floor_thr)[0])
  98. elif self.floor_thr == 0:
  99. floor_set = set(np.where(max_overlaps == 0)[0])
  100. iou_sampling_set = set(
  101. np.where(max_overlaps > self.floor_thr)[0])
  102. else:
  103. floor_set = set()
  104. iou_sampling_set = set(
  105. np.where(max_overlaps > self.floor_thr)[0])
  106. # for sampling interval calculation
  107. self.floor_thr = 0
  108. floor_neg_inds = list(floor_set & neg_set)
  109. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  110. num_expected_iou_sampling = int(num_expected *
  111. (1 - self.floor_fraction))
  112. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  113. if self.num_bins >= 2:
  114. iou_sampled_inds = self.sample_via_interval(
  115. max_overlaps, set(iou_sampling_neg_inds),
  116. num_expected_iou_sampling)
  117. else:
  118. iou_sampled_inds = self.random_choice(
  119. iou_sampling_neg_inds, num_expected_iou_sampling)
  120. else:
  121. iou_sampled_inds = np.array(
  122. iou_sampling_neg_inds, dtype=np.int64)
  123. num_expected_floor = num_expected - len(iou_sampled_inds)
  124. if len(floor_neg_inds) > num_expected_floor:
  125. sampled_floor_inds = self.random_choice(
  126. floor_neg_inds, num_expected_floor)
  127. else:
  128. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)
  129. sampled_inds = np.concatenate(
  130. (sampled_floor_inds, iou_sampled_inds))
  131. if len(sampled_inds) < num_expected:
  132. num_extra = num_expected - len(sampled_inds)
  133. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  134. if len(extra_inds) > num_extra:
  135. extra_inds = self.random_choice(extra_inds, num_extra)
  136. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  137. sampled_inds = torch.from_numpy(sampled_inds).long().to(
  138. assign_result.gt_inds.device)
  139. return sampled_inds