multi_instance_sampling_result.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from torch import Tensor
  4. from ..assigners import AssignResult
  5. from .sampling_result import SamplingResult
  6. class MultiInstanceSamplingResult(SamplingResult):
  7. """Bbox sampling result. Further encapsulation of SamplingResult. Three
  8. attributes neg_assigned_gt_inds, neg_gt_labels, and neg_gt_bboxes have been
  9. added for SamplingResult.
  10. Args:
  11. pos_inds (Tensor): Indices of positive samples.
  12. neg_inds (Tensor): Indices of negative samples.
  13. priors (Tensor): The priors can be anchors or points,
  14. or the bboxes predicted by the previous stage.
  15. gt_and_ignore_bboxes (Tensor): Ground truth and ignore bboxes.
  16. assign_result (:obj:`AssignResult`): Assigning results.
  17. gt_flags (Tensor): The Ground truth flags.
  18. avg_factor_with_neg (bool): If True, ``avg_factor`` equal to
  19. the number of total priors; Otherwise, it is the number of
  20. positive priors. Defaults to True.
  21. """
  22. def __init__(self,
  23. pos_inds: Tensor,
  24. neg_inds: Tensor,
  25. priors: Tensor,
  26. gt_and_ignore_bboxes: Tensor,
  27. assign_result: AssignResult,
  28. gt_flags: Tensor,
  29. avg_factor_with_neg: bool = True) -> None:
  30. self.neg_assigned_gt_inds = assign_result.gt_inds[neg_inds]
  31. self.neg_gt_labels = assign_result.labels[neg_inds]
  32. if gt_and_ignore_bboxes.numel() == 0:
  33. self.neg_gt_bboxes = torch.empty_like(gt_and_ignore_bboxes).view(
  34. -1, 4)
  35. else:
  36. if len(gt_and_ignore_bboxes.shape) < 2:
  37. gt_and_ignore_bboxes = gt_and_ignore_bboxes.view(-1, 4)
  38. self.neg_gt_bboxes = gt_and_ignore_bboxes[
  39. self.neg_assigned_gt_inds.long(), :]
  40. # To resist the minus 1 operation in `SamplingResult.init()`.
  41. assign_result.gt_inds += 1
  42. super().__init__(
  43. pos_inds=pos_inds,
  44. neg_inds=neg_inds,
  45. priors=priors,
  46. gt_bboxes=gt_and_ignore_bboxes,
  47. assign_result=assign_result,
  48. gt_flags=gt_flags,
  49. avg_factor_with_neg=avg_factor_with_neg)