mask_pseudo_sampler.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """copy from
  3. https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from mmdet.registry import TASK_UTILS
  7. from ..assigners import AssignResult
  8. from .base_sampler import BaseSampler
  9. from .mask_sampling_result import MaskSamplingResult
  10. @TASK_UTILS.register_module()
  11. class MaskPseudoSampler(BaseSampler):
  12. """A pseudo sampler that does not do sampling actually."""
  13. def __init__(self, **kwargs):
  14. pass
  15. def _sample_pos(self, **kwargs):
  16. """Sample positive samples."""
  17. raise NotImplementedError
  18. def _sample_neg(self, **kwargs):
  19. """Sample negative samples."""
  20. raise NotImplementedError
  21. def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
  22. gt_instances: InstanceData, *args, **kwargs):
  23. """Directly returns the positive and negative indices of samples.
  24. Args:
  25. assign_result (:obj:`AssignResult`): Mask assigning results.
  26. pred_instances (:obj:`InstanceData`): Instances of model
  27. predictions. It includes ``scores`` and ``masks`` predicted
  28. by the model.
  29. gt_instances (:obj:`InstanceData`): Ground truth of instance
  30. annotations. It usually includes ``labels`` and ``masks``
  31. attributes.
  32. Returns:
  33. :obj:`SamplingResult`: sampler results
  34. """
  35. pred_masks = pred_instances.masks
  36. gt_masks = gt_instances.masks
  37. pos_inds = torch.nonzero(
  38. assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
  39. neg_inds = torch.nonzero(
  40. assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
  41. gt_flags = pred_masks.new_zeros(pred_masks.shape[0], dtype=torch.uint8)
  42. sampling_result = MaskSamplingResult(
  43. pos_inds=pos_inds,
  44. neg_inds=neg_inds,
  45. masks=pred_masks,
  46. gt_masks=gt_masks,
  47. assign_result=assign_result,
  48. gt_flags=gt_flags,
  49. avg_factor_with_neg=False)
  50. return sampling_result