123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- from mmdet.registry import TASK_UTILS
- from mmdet.structures.bbox import bbox2roi
- from .base_sampler import BaseSampler
@TASK_UTILS.register_module()
class OHEMSampler(BaseSampler):
    r"""Online Hard Example Mining (OHEM) sampler.

    Follows `Training Region-based Object Detectors with Online Hard Example
    Mining <https://arxiv.org/abs/1604.03540>`_.

    Whenever more candidates are available than requested, the candidates
    are run through the context's bbox forward pass, an unreduced loss is
    computed with the bbox head, and the candidates with the largest loss
    values are kept.

    Args:
        num: Forwarded unchanged to ``BaseSampler.__init__``.
        pos_fraction: Forwarded unchanged to ``BaseSampler.__init__``.
        context: Object that exposes ``bbox_head`` and ``_bbox_forward``.
            If it also has a ``num_stages`` attribute, ``bbox_head`` is
            indexed with ``context.current_stage`` and ``current_stage`` is
            passed as first argument to ``_bbox_forward``.
        neg_pos_ub: Forwarded unchanged to ``BaseSampler.__init__``.
            Defaults to -1.
        add_gt_as_proposals: Forwarded unchanged to
            ``BaseSampler.__init__``. Defaults to True.
        loss_key (str): Key of the entry to read from the dict returned by
            ``bbox_head.loss``. Defaults to 'loss_cls'.
        **kwargs: Accepted but not used by this class (they are not
            forwarded to the base class).
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 loss_key='loss_cls',
                 **kwargs):
        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        self.context = context
        # A context with ``num_stages`` keeps one head per stage; the head
        # of the stage that is current at construction time is captured.
        if hasattr(context, 'num_stages'):
            self.bbox_head = context.bbox_head[context.current_stage]
        else:
            self.bbox_head = context.bbox_head
        self.loss_key = loss_key

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Select the ``num_expected`` candidates with the largest loss.

        Args:
            inds (torch.Tensor): Candidate indices. The result is taken
                from this tensor.
            num_expected (int): Number of indices to keep.
            bboxes: Boxes of the candidates, wrapped in a single-element
                list and converted with ``bbox2roi``. Presumably aligned
                one-to-one with ``inds`` - verify against callers.
            labels: Labels handed to ``bbox_head.loss``.
            feats: Features handed to ``context._bbox_forward``.

        Returns:
            torch.Tensor: Entries of ``inds`` at the positions of the
            ``num_expected`` largest loss values.
        """
        staged = hasattr(self.context, 'num_stages')
        # Scoring candidates is only used for ranking, so no gradients are
        # recorded for this forward pass and loss computation.
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            if staged:
                forward_out = self.context._bbox_forward(
                    self.context.current_stage, feats, rois)
            else:
                forward_out = self.context._bbox_forward(feats, rois)
            cls_score = forward_out['cls_score']
            # Every candidate gets weight one.
            unit_weights = cls_score.new_ones(cls_score.size(0))
            loss_dict = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=rois,
                labels=labels,
                label_weights=unit_weights,
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')
            ranking_loss = loss_dict[self.loss_key]
            hardest = ranking_loss.topk(num_expected)[1]
        return inds[hardest]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard positive samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results.
            num_expected (int): Number of expected positive samples.
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of positive samples. All positives are
            returned as-is when there are no more than ``num_expected`` of
            them; otherwise the result of :meth:`hard_mining`.
        """
        # Positives are the entries whose assigned gt index is > 0.
        candidates = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        if candidates.numel() <= num_expected:
            return candidates
        return self.hard_mining(candidates, num_expected, bboxes[candidates],
                                assign_result.labels[candidates], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard negative samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results.
            num_expected (int): Number of expected negative samples.
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of negative samples. All negatives are
            returned as-is when there are no more than ``num_expected`` of
            them; otherwise the result of :meth:`hard_mining`.
        """
        # Negatives are the entries whose assigned gt index equals 0.
        candidates = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        if candidates.size(0) <= num_expected:
            return candidates
        # Every negative candidate is labelled ``bbox_head.num_classes``
        # (presumably the background index - confirm against the head).
        neg_labels = assign_result.labels.new_full(
            (candidates.size(0), ), self.bbox_head.num_classes)
        return self.hard_mining(candidates, num_expected, bboxes[candidates],
                                neg_labels, feats)
|