123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import numpy as np
- import torch
- from mmdet.registry import TASK_UTILS
- from .random_sampler import RandomSampler
@TASK_UTILS.register_module()
class IoUBalancedNegSampler(RandomSampler):
    """IoU Balanced Sampling.

    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)

    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
    are sampled from proposals whose IoU are lower than `floor_thr` randomly.
    The others are sampled from proposals whose IoU are higher than
    `floor_thr`. These proposals are sampled from some bins evenly, which are
    split by `num_bins` via IoU evenly.

    Args:
        num (int): number of proposals.
        pos_fraction (float): fraction of positive proposals.
        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
            set to -1 if all using IoU balanced sampling.
        floor_fraction (float): sampling fraction of proposals under floor_thr.
        num_bins (int): number of bins in IoU balanced sampling.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 floor_thr=-1,
                 floor_fraction=0,
                 num_bins=3,
                 **kwargs):
        super().__init__(num, pos_fraction, **kwargs)
        assert floor_thr >= 0 or floor_thr == -1
        assert 0 <= floor_fraction <= 1
        assert num_bins >= 1

        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins

    def _top_up(self, sampled_inds, pool, num_expected):
        """Pad ``sampled_inds`` with random, not-yet-sampled indices.

        If ``sampled_inds`` already holds at least ``num_expected`` entries
        it is returned unchanged. Otherwise up to
        ``num_expected - len(sampled_inds)`` indices are drawn at random from
        ``pool`` minus the indices already sampled, and appended.

        Args:
            sampled_inds (np.ndarray): Indices sampled so far.
            pool (set(int)): Indices allowed to be used for padding.
            num_expected (int): Desired total number of indices.

        Returns:
            np.ndarray: ``sampled_inds`` followed by the padding indices.
        """
        if len(sampled_inds) >= num_expected:
            return sampled_inds
        num_extra = num_expected - len(sampled_inds)
        # Pin the dtype: an empty remainder would otherwise be float64 and
        # promote the concatenated index array to float64 as well.
        extra_inds = np.array(
            list(pool - set(sampled_inds)), dtype=np.int64)
        if len(extra_inds) > num_extra:
            extra_inds = self.random_choice(extra_inds, num_extra)
        return np.concatenate([sampled_inds, extra_inds])

    def sample_via_interval(self,
                            max_overlaps,
                            full_set,
                            num_expected,
                            floor_thr=None):
        """Sample according to the iou interval.

        The IoU range from ``floor_thr`` up to ``max_overlaps.max()`` is split
        into ``self.num_bins`` equal-width, half-open bins
        ``[start, end)``, and at most ``int(num_expected / self.num_bins)``
        indices of ``full_set`` are drawn at random from each bin. Because
        the bins are half-open, an index whose IoU equals the overall
        maximum generally lands in no bin and is only reachable through the
        final top-up, which fills any shortfall with random unused indices
        from ``full_set``.

        Args:
            max_overlaps (np.ndarray): IoU between bounding boxes and ground
                truth boxes.
            full_set (set(int)): A full set of indices of boxes.
            num_expected (int): Number of expected samples.
            floor_thr (float, optional): Lower edge of the first IoU bin.
                Defaults to ``self.floor_thr``.

        Returns:
            np.ndarray: Indices of samples.
        """
        if floor_thr is None:
            floor_thr = self.floor_thr
        max_iou = max_overlaps.max()
        iou_interval = (max_iou - floor_thr) / self.num_bins
        per_num_expected = int(num_expected / self.num_bins)

        sampled_inds = []
        for i in range(self.num_bins):
            start_iou = floor_thr + i * iou_interval
            end_iou = floor_thr + (i + 1) * iou_interval
            tmp_set = set(
                np.where(
                    np.logical_and(max_overlaps >= start_iou,
                                   max_overlaps < end_iou))[0])
            tmp_inds = list(tmp_set & full_set)
            if len(tmp_inds) > per_num_expected:
                tmp_sampled_set = self.random_choice(tmp_inds,
                                                     per_num_expected)
            else:
                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
            sampled_inds.append(tmp_sampled_set)

        sampled_inds = np.concatenate(sampled_inds)
        return self._top_up(sampled_inds, full_set, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative boxes.

        Negatives (``assign_result.gt_inds == 0``) are split into two pools:

        - a "floor" pool: IoU in ``[0, floor_thr)`` when ``floor_thr > 0``,
          IoU exactly 0 when ``floor_thr == 0``, empty when
          ``floor_thr == -1``;
        - an "IoU-sampling" pool: the negatives with higher IoU.

        ``int(num_expected * (1 - floor_fraction))`` indices are drawn from
        the IoU-sampling pool (bin-balanced via :meth:`sample_via_interval`
        when ``num_bins >= 2``, uniformly otherwise), the floor pool supplies
        the rest, and any remaining shortfall is filled with random unused
        negatives.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected negative samples.

        Returns:
            torch.Tensor: Sampled indices, on the device of
            ``assign_result.gt_inds``.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds

        max_overlaps = assign_result.max_overlaps.cpu().numpy()
        # balance sampling for negative samples
        neg_set = set(neg_inds.cpu().numpy())

        # Use a local threshold: the ``floor_thr == -1`` case needs an
        # interval origin of 0, and writing that back to ``self.floor_thr``
        # would silently switch every later call to the ``== 0`` branch.
        floor_thr = self.floor_thr
        if floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < floor_thr))[0])
            iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
        elif floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
        else:
            floor_set = set()
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
            # for sampling interval calculation only
            floor_thr = 0

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)

        num_expected_iou_sampling = int(num_expected *
                                        (1 - self.floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if self.num_bins >= 2:
                iou_sampled_inds = self.sample_via_interval(
                    max_overlaps,
                    set(iou_sampling_neg_inds),
                    num_expected_iou_sampling,
                    floor_thr=floor_thr)
            else:
                iou_sampled_inds = self.random_choice(
                    iou_sampling_neg_inds, num_expected_iou_sampling)
        else:
            iou_sampled_inds = np.array(
                iou_sampling_neg_inds, dtype=np.int64)

        # The floor pool covers whatever the IoU-sampling pool did not.
        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = self.random_choice(
                floor_neg_inds, num_expected_floor)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)

        sampled_inds = np.concatenate(
            (sampled_floor_inds, iou_sampled_inds))
        sampled_inds = self._top_up(sampled_inds, neg_set, num_expected)
        return torch.from_numpy(sampled_inds).long().to(
            assign_result.gt_inds.device)
|