ohem_sampler.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import bbox2roi
from .base_sampler import BaseSampler


@TASK_UTILS.register_module()
class OHEMSampler(BaseSampler):
    r"""Online Hard Example Mining Sampler described in `Training Region-based
    Object Detectors with Online Hard Example Mining
    <https://arxiv.org/abs/1604.03540>`_.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 loss_key='loss_cls',
                 **kwargs):
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        self.context = context
        # ``context`` is the RoI head that owns the bbox head(s). Cascade
        # heads expose ``num_stages`` and keep one bbox head per stage, so
        # pick the head of the current stage in that case.
        if not hasattr(self.context, 'num_stages'):
            self.bbox_head = self.context.bbox_head
        else:
            self.bbox_head = self.context.bbox_head[
                self.context.current_stage]

        self.loss_key = loss_key

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Pick the ``num_expected`` candidates in ``inds`` with the highest
        classification loss (the "hardest" examples)."""
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            if not hasattr(self.context, 'num_stages'):
                bbox_results = self.context._bbox_forward(feats, rois)
            else:
                bbox_results = self.context._bbox_forward(
                    self.context.current_stage, feats, rois)
            cls_score = bbox_results['cls_score']
            # Per-sample classification loss (reduction disabled) serves as
            # the hardness score of each candidate.
            loss = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=rois,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')[self.loss_key]
            _, topk_loss_inds = loss.topk(num_expected)
        return inds[topk_loss_inds]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample positive boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected positive samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of positive samples
        """
        # Sample some hard positive samples
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
                                    assign_result.labels[pos_inds], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample negative boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected negative samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of negative samples
        """
        # Sample some hard negative samples
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            # Negatives are assigned the background class index
            # (``num_classes``) so their classification loss can be computed.
            neg_labels = assign_result.labels.new_empty(
                neg_inds.size(0)).fill_(self.bbox_head.num_classes)
            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
                                    neg_labels, feats)
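
Usage sketch (assumption, not part of this file): in an MMDetection-style config the sampler is typically enabled by swapping the R-CNN sampler inside train_cfg. The keys and values below follow the common Faster R-CNN OHEM setup; the context argument is injected by the RoI head when it builds the sampler, so it never appears in the config.

# Hypothetical config fragment; the numeric values mirror the usual
# Faster R-CNN OHEM settings and are assumptions, not taken from this file.
train_cfg = dict(
    rcnn=dict(
        sampler=dict(
            type='OHEMSampler',        # resolved through the TASK_UTILS registry
            num=512,                   # total RoIs sampled per image
            pos_fraction=0.25,         # at most 25% of them positives
            neg_pos_ub=-1,             # no cap on the negative/positive ratio
            add_gt_as_proposals=True)))  # add GT boxes to the proposal set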