instance_balanced_pos_sampler.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmdet.registry import TASK_UTILS
  5. from .random_sampler import RandomSampler
  6. @TASK_UTILS.register_module()
  7. class InstanceBalancedPosSampler(RandomSampler):
  8. """Instance balanced sampler that samples equal number of positive samples
  9. for each instance."""
  10. def _sample_pos(self, assign_result, num_expected, **kwargs):
  11. """Sample positive boxes.
  12. Args:
  13. assign_result (:obj:`AssignResult`): The assigned results of boxes.
  14. num_expected (int): The number of expected positive samples
  15. Returns:
  16. Tensor or ndarray: sampled indices.
  17. """
  18. pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
  19. if pos_inds.numel() != 0:
  20. pos_inds = pos_inds.squeeze(1)
  21. if pos_inds.numel() <= num_expected:
  22. return pos_inds
  23. else:
  24. unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
  25. num_gts = len(unique_gt_inds)
  26. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  27. sampled_inds = []
  28. for i in unique_gt_inds:
  29. inds = torch.nonzero(
  30. assign_result.gt_inds == i.item(), as_tuple=False)
  31. if inds.numel() != 0:
  32. inds = inds.squeeze(1)
  33. else:
  34. continue
  35. if len(inds) > num_per_gt:
  36. inds = self.random_choice(inds, num_per_gt)
  37. sampled_inds.append(inds)
  38. sampled_inds = torch.cat(sampled_inds)
  39. if len(sampled_inds) < num_expected:
  40. num_extra = num_expected - len(sampled_inds)
  41. extra_inds = np.array(
  42. list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
  43. if len(extra_inds) > num_extra:
  44. extra_inds = self.random_choice(extra_inds, num_extra)
  45. extra_inds = torch.from_numpy(extra_inds).to(
  46. assign_result.gt_inds.device).long()
  47. sampled_inds = torch.cat([sampled_inds, extra_inds])
  48. elif len(sampled_inds) > num_expected:
  49. sampled_inds = self.random_choice(sampled_inds, num_expected)
  50. return sampled_inds