1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import numpy as np
- import torch
- from mmdet.registry import TASK_UTILS
- from .random_sampler import RandomSampler
@TASK_UTILS.register_module()
class InstanceBalancedPosSampler(RandomSampler):
    """Instance balanced sampler that samples equal number of positive samples
    for each instance."""

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive boxes, spreading the picks across assigned GTs.

        Every box whose ``gt_inds`` entry is greater than zero is a positive.
        When there are more positives than ``num_expected``, roughly the same
        number is drawn for each distinct GT index; any shortfall is topped up
        from positives that were not picked yet, and any excess is randomly
        trimmed.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected positive samples

        Returns:
            Tensor or ndarray: sampled indices.
        """
        gt_inds = assign_result.gt_inds
        pos_inds = torch.nonzero(gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            # Not enough positives to need sub-sampling: keep them all.
            return pos_inds

        unique_gt_inds = gt_inds[pos_inds].unique()
        num_gts = len(unique_gt_inds)
        # The "+ 1" over-allocates per GT so rounding down never leaves the
        # total short when every GT has enough boxes; excess is trimmed below.
        num_per_gt = int(round(num_expected / float(num_gts)) + 1)

        sampled_inds = []
        for i in unique_gt_inds:
            inds = torch.nonzero(gt_inds == i.item(), as_tuple=False)
            if inds.numel() == 0:
                continue
            inds = inds.squeeze(1)
            if len(inds) > num_per_gt:
                inds = self.random_choice(inds, num_per_gt)
            sampled_inds.append(inds)
        sampled_inds = torch.cat(sampled_inds)

        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            # Exclude already-sampled positives by value using a boolean mask.
            # A Python ``set`` of 0-d tensors hashes by object identity, so a
            # set difference would not remove anything and the top-up could
            # re-pick indices that are already in ``sampled_inds``.
            already_sampled = torch.zeros_like(gt_inds, dtype=torch.bool)
            already_sampled[sampled_inds] = True
            extra_inds = pos_inds[~already_sampled[pos_inds]]
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = torch.cat([sampled_inds, extra_inds])
        elif len(sampled_inds) > num_expected:
            sampled_inds = self.random_choice(sampled_inds, num_expected)
        return sampled_inds
|