random_sampler.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. import torch
  4. from numpy import ndarray
  5. from torch import Tensor
  6. from mmdet.registry import TASK_UTILS
  7. from ..assigners import AssignResult
  8. from .base_sampler import BaseSampler
  9. @TASK_UTILS.register_module()
  10. class RandomSampler(BaseSampler):
  11. """Random sampler.
  12. Args:
  13. num (int): Number of samples
  14. pos_fraction (float): Fraction of positive samples
  15. neg_pos_up (int): Upper bound number of negative and
  16. positive samples. Defaults to -1.
  17. add_gt_as_proposals (bool): Whether to add ground truth
  18. boxes as proposals. Defaults to True.
  19. """
  20. def __init__(self,
  21. num: int,
  22. pos_fraction: float,
  23. neg_pos_ub: int = -1,
  24. add_gt_as_proposals: bool = True,
  25. **kwargs):
  26. from .sampling_result import ensure_rng
  27. super().__init__(
  28. num=num,
  29. pos_fraction=pos_fraction,
  30. neg_pos_ub=neg_pos_ub,
  31. add_gt_as_proposals=add_gt_as_proposals)
  32. self.rng = ensure_rng(kwargs.get('rng', None))
  33. def random_choice(self, gallery: Union[Tensor, ndarray, list],
  34. num: int) -> Union[Tensor, ndarray]:
  35. """Random select some elements from the gallery.
  36. If `gallery` is a Tensor, the returned indices will be a Tensor;
  37. If `gallery` is a ndarray or list, the returned indices will be a
  38. ndarray.
  39. Args:
  40. gallery (Tensor | ndarray | list): indices pool.
  41. num (int): expected sample num.
  42. Returns:
  43. Tensor or ndarray: sampled indices.
  44. """
  45. assert len(gallery) >= num
  46. is_tensor = isinstance(gallery, torch.Tensor)
  47. if not is_tensor:
  48. if torch.cuda.is_available():
  49. device = torch.cuda.current_device()
  50. else:
  51. device = 'cpu'
  52. gallery = torch.tensor(gallery, dtype=torch.long, device=device)
  53. # This is a temporary fix. We can revert the following code
  54. # when PyTorch fixes the abnormal return of torch.randperm.
  55. # See: https://github.com/open-mmlab/mmdetection/pull/5014
  56. perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
  57. rand_inds = gallery[perm]
  58. if not is_tensor:
  59. rand_inds = rand_inds.cpu().numpy()
  60. return rand_inds
  61. def _sample_pos(self, assign_result: AssignResult, num_expected: int,
  62. **kwargs) -> Union[Tensor, ndarray]:
  63. """Randomly sample some positive samples.
  64. Args:
  65. assign_result (:obj:`AssignResult`): Bbox assigning results.
  66. num_expected (int): The number of expected positive samples
  67. Returns:
  68. Tensor or ndarray: sampled indices.
  69. """
  70. pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
  71. if pos_inds.numel() != 0:
  72. pos_inds = pos_inds.squeeze(1)
  73. if pos_inds.numel() <= num_expected:
  74. return pos_inds
  75. else:
  76. return self.random_choice(pos_inds, num_expected)
  77. def _sample_neg(self, assign_result: AssignResult, num_expected: int,
  78. **kwargs) -> Union[Tensor, ndarray]:
  79. """Randomly sample some negative samples.
  80. Args:
  81. assign_result (:obj:`AssignResult`): Bbox assigning results.
  82. num_expected (int): The number of expected positive samples
  83. Returns:
  84. Tensor or ndarray: sampled indices.
  85. """
  86. neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
  87. if neg_inds.numel() != 0:
  88. neg_inds = neg_inds.squeeze(1)
  89. if len(neg_inds) <= num_expected:
  90. return neg_inds
  91. else:
  92. return self.random_choice(neg_inds, num_expected)