123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Union
- import torch
- from numpy import ndarray
- from torch import Tensor
- from mmdet.registry import TASK_UTILS
- from ..assigners import AssignResult
- from .base_sampler import BaseSampler
@TASK_UTILS.register_module()
class RandomSampler(BaseSampler):
    """Random sampler.

    Picks positive and negative indices uniformly at random from the
    entries of an :obj:`AssignResult`.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int): Upper bound number of negative and
            positive samples. Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Only the ``rng`` key is read. It is passed to
            ``ensure_rng`` and the result is stored as ``self.rng``.
            Any other keys are ignored. Note that :meth:`random_choice`
            draws from torch's global RNG, not from ``self.rng``.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs):
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with ``sampling_result``; confirm before hoisting it.
        from .sampling_result import ensure_rng

        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        self.rng = ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        """Randomly select ``num`` elements from the gallery.

        If `gallery` is a Tensor, the returned indices will be a Tensor
        on the same device; if `gallery` is a ndarray or list, the
        returned indices will be an int64 ndarray.

        Args:
            gallery (Tensor | ndarray | list): indices pool.
            num (int): expected sample num. Must not exceed
                ``len(gallery)``.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # The result is converted back to an ndarray below, so keep the
            # temporary tensor on CPU; placing it on the GPU would only add
            # a host->device and a device->host copy.
            gallery = torch.tensor(gallery, dtype=torch.long, device='cpu')
        # This is a temporary fix. We can revert the following code
        # when PyTorch fixes the abnormal return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some positive samples.

        Positives are the entries whose ``assign_result.gt_inds`` is > 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples.

        Returns:
            Tensor or ndarray: sampled indices. If there are at most
            ``num_expected`` positives, all of them are returned in
            their original order.
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            # nonzero(as_tuple=False) returns one column per dimension;
            # squeeze(1) assumes gt_inds is 1-D (TODO confirm).
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        return self.random_choice(pos_inds, num_expected)

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some negative samples.

        Negatives are the entries whose ``assign_result.gt_inds`` is == 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor or ndarray: sampled indices. If there are at most
            ``num_expected`` negatives, all of them are returned in
            their original order.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            # nonzero(as_tuple=False) returns one column per dimension;
            # squeeze(1) assumes gt_inds is 1-D (TODO confirm).
            neg_inds = neg_inds.squeeze(1)
        # numel() (not len()) for consistency with _sample_pos.
        if neg_inds.numel() <= num_expected:
            return neg_inds
        return self.random_choice(neg_inds, num_expected)
|