combined_sampler.py 760 B

123456789101112131415161718192021
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmdet.registry import TASK_UTILS
  3. from .base_sampler import BaseSampler
  @TASK_UTILS.register_module()
  class CombinedSampler(BaseSampler):
      """A sampler that combines positive sampler and negative sampler.

      The positive and negative sub-samplers are built from their own configs
      through ``TASK_UTILS``. Any extra keyword arguments go to the base
      sampler's constructor and are also given to both builds as
      ``default_args``.

      Args:
          pos_sampler: Config used to build the positive sub-sampler.
          neg_sampler: Config used to build the negative sub-sampler.
          **kwargs: Forwarded to ``BaseSampler.__init__`` and used as
              ``default_args`` for both sub-sampler builds.
      """

      def __init__(self, pos_sampler, neg_sampler, **kwargs):
          # The base constructor must run first. The sub-sampler attributes
          # are assigned afterwards, so they are the final values stored on
          # the instance.
          super().__init__(**kwargs)
          # The same keyword arguments given to the base sampler act as
          # defaults for both sub-sampler configs.
          # NOTE(review): presumably `default_args` only fills keys missing
          # from each config -- confirm against the registry's build().
          shared_defaults = kwargs
          self.pos_sampler = TASK_UTILS.build(
              pos_sampler, default_args=shared_defaults)
          self.neg_sampler = TASK_UTILS.build(
              neg_sampler, default_args=shared_defaults)

      def _sample_pos(self, **kwargs):
          """Sample positive samples.

          The combined sampler does not implement this itself; calling it
          always raises.
          NOTE(review): presumably positive sampling goes through
          ``self.pos_sampler`` -- confirm in the base class.

          Raises:
              NotImplementedError: Always.
          """
          raise NotImplementedError()

      def _sample_neg(self, **kwargs):
          """Sample negative samples.

          The combined sampler does not implement this itself; calling it
          always raises.
          NOTE(review): presumably negative sampling goes through
          ``self.neg_sampler`` -- confirm in the base class.

          Raises:
              NotImplementedError: Always.
          """
          raise NotImplementedError()