assign_result.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from torch import Tensor
  4. from mmdet.utils import util_mixins
  5. class AssignResult(util_mixins.NiceRepr):
  6. """Stores assignments between predicted and truth boxes.
  7. Attributes:
  8. num_gts (int): the number of truth boxes considered when computing this
  9. assignment
  10. gt_inds (Tensor): for each predicted box indicates the 1-based
  11. index of the assigned truth box. 0 means unassigned and -1 means
  12. ignore.
  13. max_overlaps (Tensor): the iou between the predicted box and its
  14. assigned truth box.
  15. labels (Tensor): If specified, for each predicted box
  16. indicates the category label of the assigned truth box.
  17. Example:
  18. >>> # An assign result between 4 predicted boxes and 9 true boxes
  19. >>> # where only two boxes were assigned.
  20. >>> num_gts = 9
  21. >>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
  22. >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
  23. >>> labels = torch.LongTensor([0, 3, 4, 0])
  24. >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
  25. >>> print(str(self)) # xdoctest: +IGNORE_WANT
  26. <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
  27. labels.shape=(4,))>
  28. >>> # Force addition of gt labels (when adding gt as proposals)
  29. >>> new_labels = torch.LongTensor([3, 4, 5])
  30. >>> self.add_gt_(new_labels)
  31. >>> print(str(self)) # xdoctest: +IGNORE_WANT
  32. <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
  33. labels.shape=(7,))>
  34. """
  35. def __init__(self, num_gts: int, gt_inds: Tensor, max_overlaps: Tensor,
  36. labels: Tensor) -> None:
  37. self.num_gts = num_gts
  38. self.gt_inds = gt_inds
  39. self.max_overlaps = max_overlaps
  40. self.labels = labels
  41. # Interface for possible user-defined properties
  42. self._extra_properties = {}
  43. @property
  44. def num_preds(self):
  45. """int: the number of predictions in this assignment"""
  46. return len(self.gt_inds)
  47. def set_extra_property(self, key, value):
  48. """Set user-defined new property."""
  49. assert key not in self.info
  50. self._extra_properties[key] = value
  51. def get_extra_property(self, key):
  52. """Get user-defined property."""
  53. return self._extra_properties.get(key, None)
  54. @property
  55. def info(self):
  56. """dict: a dictionary of info about the object"""
  57. basic_info = {
  58. 'num_gts': self.num_gts,
  59. 'num_preds': self.num_preds,
  60. 'gt_inds': self.gt_inds,
  61. 'max_overlaps': self.max_overlaps,
  62. 'labels': self.labels,
  63. }
  64. basic_info.update(self._extra_properties)
  65. return basic_info
  66. def __nice__(self):
  67. """str: a "nice" summary string describing this assign result"""
  68. parts = []
  69. parts.append(f'num_gts={self.num_gts!r}')
  70. if self.gt_inds is None:
  71. parts.append(f'gt_inds={self.gt_inds!r}')
  72. else:
  73. parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
  74. if self.max_overlaps is None:
  75. parts.append(f'max_overlaps={self.max_overlaps!r}')
  76. else:
  77. parts.append('max_overlaps.shape='
  78. f'{tuple(self.max_overlaps.shape)!r}')
  79. if self.labels is None:
  80. parts.append(f'labels={self.labels!r}')
  81. else:
  82. parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
  83. return ', '.join(parts)
  84. @classmethod
  85. def random(cls, **kwargs):
  86. """Create random AssignResult for tests or debugging.
  87. Args:
  88. num_preds: number of predicted boxes
  89. num_gts: number of true boxes
  90. p_ignore (float): probability of a predicted box assigned to an
  91. ignored truth
  92. p_assigned (float): probability of a predicted box not being
  93. assigned
  94. p_use_label (float | bool): with labels or not
  95. rng (None | int | numpy.random.RandomState): seed or state
  96. Returns:
  97. :obj:`AssignResult`: Randomly generated assign results.
  98. Example:
  99. >>> from mmdet.models.task_modules.assigners.assign_result import * # NOQA
  100. >>> self = AssignResult.random()
  101. >>> print(self.info)
  102. """
  103. from ..samplers.sampling_result import ensure_rng
  104. rng = ensure_rng(kwargs.get('rng', None))
  105. num_gts = kwargs.get('num_gts', None)
  106. num_preds = kwargs.get('num_preds', None)
  107. p_ignore = kwargs.get('p_ignore', 0.3)
  108. p_assigned = kwargs.get('p_assigned', 0.7)
  109. num_classes = kwargs.get('num_classes', 3)
  110. if num_gts is None:
  111. num_gts = rng.randint(0, 8)
  112. if num_preds is None:
  113. num_preds = rng.randint(0, 16)
  114. if num_gts == 0:
  115. max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
  116. gt_inds = torch.zeros(num_preds, dtype=torch.int64)
  117. labels = torch.zeros(num_preds, dtype=torch.int64)
  118. else:
  119. import numpy as np
  120. # Create an overlap for each predicted box
  121. max_overlaps = torch.from_numpy(rng.rand(num_preds))
  122. # Construct gt_inds for each predicted box
  123. is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
  124. # maximum number of assignments constraints
  125. n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
  126. assigned_idxs = np.where(is_assigned)[0]
  127. rng.shuffle(assigned_idxs)
  128. assigned_idxs = assigned_idxs[0:n_assigned]
  129. assigned_idxs.sort()
  130. is_assigned[:] = 0
  131. is_assigned[assigned_idxs] = True
  132. is_ignore = torch.from_numpy(
  133. rng.rand(num_preds) < p_ignore) & is_assigned
  134. gt_inds = torch.zeros(num_preds, dtype=torch.int64)
  135. true_idxs = np.arange(num_gts)
  136. rng.shuffle(true_idxs)
  137. true_idxs = torch.from_numpy(true_idxs)
  138. gt_inds[is_assigned] = true_idxs[:n_assigned].long()
  139. gt_inds = torch.from_numpy(
  140. rng.randint(1, num_gts + 1, size=num_preds))
  141. gt_inds[is_ignore] = -1
  142. gt_inds[~is_assigned] = 0
  143. max_overlaps[~is_assigned] = 0
  144. if num_classes == 0:
  145. labels = torch.zeros(num_preds, dtype=torch.int64)
  146. else:
  147. labels = torch.from_numpy(
  148. # remind that we set FG labels to [0, num_class-1]
  149. # since mmdet v2.0
  150. # BG cat_id: num_class
  151. rng.randint(0, num_classes, size=num_preds))
  152. labels[~is_assigned] = 0
  153. self = cls(num_gts, gt_inds, max_overlaps, labels)
  154. return self
  155. def add_gt_(self, gt_labels):
  156. """Add ground truth as assigned results.
  157. Args:
  158. gt_labels (torch.Tensor): Labels of gt boxes
  159. """
  160. self_inds = torch.arange(
  161. 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
  162. self.gt_inds = torch.cat([self_inds, self.gt_inds])
  163. self.max_overlaps = torch.cat(
  164. [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
  165. self.labels = torch.cat([gt_labels, self.labels])