pseudo_sampler.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmengine.structures import InstanceData
  4. from mmdet.registry import TASK_UTILS
  5. from ..assigners import AssignResult
  6. from .base_sampler import BaseSampler
  7. from .sampling_result import SamplingResult
  8. @TASK_UTILS.register_module()
  9. class PseudoSampler(BaseSampler):
  10. """A pseudo sampler that does not do sampling actually."""
  11. def __init__(self, **kwargs):
  12. pass
  13. def _sample_pos(self, **kwargs):
  14. """Sample positive samples."""
  15. raise NotImplementedError
  16. def _sample_neg(self, **kwargs):
  17. """Sample negative samples."""
  18. raise NotImplementedError
  19. def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
  20. gt_instances: InstanceData, *args, **kwargs):
  21. """Directly returns the positive and negative indices of samples.
  22. Args:
  23. assign_result (:obj:`AssignResult`): Bbox assigning results.
  24. pred_instances (:obj:`InstanceData`): Instances of model
  25. predictions. It includes ``priors``, and the priors can
  26. be anchors, points, or bboxes predicted by the model,
  27. shape(n, 4).
  28. gt_instances (:obj:`InstanceData`): Ground truth of instance
  29. annotations. It usually includes ``bboxes`` and ``labels``
  30. attributes.
  31. Returns:
  32. :obj:`SamplingResult`: sampler results
  33. """
  34. gt_bboxes = gt_instances.bboxes
  35. priors = pred_instances.priors
  36. pos_inds = torch.nonzero(
  37. assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
  38. neg_inds = torch.nonzero(
  39. assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
  40. gt_flags = priors.new_zeros(priors.shape[0], dtype=torch.uint8)
  41. sampling_result = SamplingResult(
  42. pos_inds=pos_inds,
  43. neg_inds=neg_inds,
  44. priors=priors,
  45. gt_bboxes=gt_bboxes,
  46. assign_result=assign_result,
  47. gt_flags=gt_flags,
  48. avg_factor_with_neg=False)
  49. return sampling_result