mask_sampling_result.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """copy from
  3. https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
  4. import torch
  5. from torch import Tensor
  6. from ..assigners import AssignResult
  7. from .sampling_result import SamplingResult
  8. class MaskSamplingResult(SamplingResult):
  9. """Mask sampling result."""
  10. def __init__(self,
  11. pos_inds: Tensor,
  12. neg_inds: Tensor,
  13. masks: Tensor,
  14. gt_masks: Tensor,
  15. assign_result: AssignResult,
  16. gt_flags: Tensor,
  17. avg_factor_with_neg: bool = True) -> None:
  18. self.pos_inds = pos_inds
  19. self.neg_inds = neg_inds
  20. self.num_pos = max(pos_inds.numel(), 1)
  21. self.num_neg = max(neg_inds.numel(), 1)
  22. self.avg_factor = self.num_pos + self.num_neg \
  23. if avg_factor_with_neg else self.num_pos
  24. self.pos_masks = masks[pos_inds]
  25. self.neg_masks = masks[neg_inds]
  26. self.pos_is_gt = gt_flags[pos_inds]
  27. self.num_gts = gt_masks.shape[0]
  28. self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
  29. if gt_masks.numel() == 0:
  30. # hack for index error case
  31. assert self.pos_assigned_gt_inds.numel() == 0
  32. self.pos_gt_masks = torch.empty_like(gt_masks)
  33. else:
  34. self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
  35. @property
  36. def masks(self) -> Tensor:
  37. """torch.Tensor: concatenated positive and negative masks."""
  38. return torch.cat([self.pos_masks, self.neg_masks])
  39. def __nice__(self) -> str:
  40. data = self.info.copy()
  41. data['pos_masks'] = data.pop('pos_masks').shape
  42. data['neg_masks'] = data.pop('neg_masks').shape
  43. parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
  44. body = ' ' + ',\n '.join(parts)
  45. return '{\n' + body + '\n}'
  46. @property
  47. def info(self) -> dict:
  48. """Returns a dictionary of info about the object."""
  49. return {
  50. 'pos_inds': self.pos_inds,
  51. 'neg_inds': self.neg_inds,
  52. 'pos_masks': self.pos_masks,
  53. 'neg_masks': self.neg_masks,
  54. 'pos_is_gt': self.pos_is_gt,
  55. 'num_gts': self.num_gts,
  56. 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
  57. }