center_region_assigner.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from torch import Tensor
  6. from mmdet.registry import TASK_UTILS
  7. from mmdet.utils import ConfigType
  8. from .assign_result import AssignResult
  9. from .base_assigner import BaseAssigner
  10. def scale_boxes(bboxes: Tensor, scale: float) -> Tensor:
  11. """Expand an array of boxes by a given scale.
  12. Args:
  13. bboxes (Tensor): Shape (m, 4)
  14. scale (float): The scale factor of bboxes
  15. Returns:
  16. Tensor: Shape (m, 4). Scaled bboxes
  17. """
  18. assert bboxes.size(1) == 4
  19. w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
  20. h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
  21. x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
  22. y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
  23. w_half *= scale
  24. h_half *= scale
  25. boxes_scaled = torch.zeros_like(bboxes)
  26. boxes_scaled[:, 0] = x_c - w_half
  27. boxes_scaled[:, 2] = x_c + w_half
  28. boxes_scaled[:, 1] = y_c - h_half
  29. boxes_scaled[:, 3] = y_c + h_half
  30. return boxes_scaled
  31. def is_located_in(points: Tensor, bboxes: Tensor) -> Tensor:
  32. """Are points located in bboxes.
  33. Args:
  34. points (Tensor): Points, shape: (m, 2).
  35. bboxes (Tensor): Bounding boxes, shape: (n, 4).
  36. Return:
  37. Tensor: Flags indicating if points are located in bboxes,
  38. shape: (m, n).
  39. """
  40. assert points.size(1) == 2
  41. assert bboxes.size(1) == 4
  42. return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
  43. (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
  44. (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
  45. (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
  46. def bboxes_area(bboxes: Tensor) -> Tensor:
  47. """Compute the area of an array of bboxes.
  48. Args:
  49. bboxes (Tensor): The coordinates ox bboxes. Shape: (m, 4)
  50. Returns:
  51. Tensor: Area of the bboxes. Shape: (m, )
  52. """
  53. assert bboxes.size(1) == 4
  54. w = (bboxes[:, 2] - bboxes[:, 0])
  55. h = (bboxes[:, 3] - bboxes[:, 1])
  56. areas = w * h
  57. return areas
  58. @TASK_UTILS.register_module()
  59. class CenterRegionAssigner(BaseAssigner):
  60. """Assign pixels at the center region of a bbox as positive.
  61. Each proposals will be assigned with `-1`, `0`, or a positive integer
  62. indicating the ground truth index.
  63. - -1: negative samples
  64. - semi-positive numbers: positive sample, index (0-based) of assigned gt
  65. Args:
  66. pos_scale (float): Threshold within which pixels are
  67. labelled as positive.
  68. neg_scale (float): Threshold above which pixels are
  69. labelled as positive.
  70. min_pos_iof (float): Minimum iof of a pixel with a gt to be
  71. labelled as positive. Default: 1e-2
  72. ignore_gt_scale (float): Threshold within which the pixels
  73. are ignored when the gt is labelled as shadowed. Default: 0.5
  74. foreground_dominate (bool): If True, the bbox will be assigned as
  75. positive when a gt's kernel region overlaps with another's shadowed
  76. (ignored) region, otherwise it is set as ignored. Default to False.
  77. iou_calculator (:obj:`ConfigDict` or dict): Config of overlaps
  78. Calculator.
  79. """
  80. def __init__(
  81. self,
  82. pos_scale: float,
  83. neg_scale: float,
  84. min_pos_iof: float = 1e-2,
  85. ignore_gt_scale: float = 0.5,
  86. foreground_dominate: bool = False,
  87. iou_calculator: ConfigType = dict(type='BboxOverlaps2D')
  88. ) -> None:
  89. self.pos_scale = pos_scale
  90. self.neg_scale = neg_scale
  91. self.min_pos_iof = min_pos_iof
  92. self.ignore_gt_scale = ignore_gt_scale
  93. self.foreground_dominate = foreground_dominate
  94. self.iou_calculator = TASK_UTILS.build(iou_calculator)
  95. def get_gt_priorities(self, gt_bboxes: Tensor) -> Tensor:
  96. """Get gt priorities according to their areas.
  97. Smaller gt has higher priority.
  98. Args:
  99. gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
  100. Returns:
  101. Tensor: The priority of gts so that gts with larger priority is
  102. more likely to be assigned. Shape (k, )
  103. """
  104. gt_areas = bboxes_area(gt_bboxes)
  105. # Rank all gt bbox areas. Smaller objects has larger priority
  106. _, sort_idx = gt_areas.sort(descending=True)
  107. sort_idx = sort_idx.argsort()
  108. return sort_idx
  109. def assign(self,
  110. pred_instances: InstanceData,
  111. gt_instances: InstanceData,
  112. gt_instances_ignore: Optional[InstanceData] = None,
  113. **kwargs) -> AssignResult:
  114. """Assign gt to bboxes.
  115. This method assigns gts to every prior (proposal/anchor), each prior
  116. will be assigned with -1, or a semi-positive number. -1 means
  117. negative sample, semi-positive number is the index (0-based) of
  118. assigned gt.
  119. Args:
  120. pred_instances (:obj:`InstanceData`): Instances of model
  121. predictions. It includes ``priors``, and the priors can
  122. be anchors or points, or the bboxes predicted by the
  123. previous stage, has shape (n, 4). The bboxes predicted by
  124. the current model or stage will be named ``bboxes``,
  125. ``labels``, and ``scores``, the same as the ``InstanceData``
  126. in other places.
  127. gt_instances (:obj:`InstanceData`): Ground truth of instance
  128. annotations. It usually includes ``bboxes``, with shape (k, 4),
  129. and ``labels``, with shape (k, ).
  130. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  131. to be ignored during training. It includes ``bboxes``
  132. attribute data that is ignored during training and testing.
  133. Defaults to None.
  134. Returns:
  135. :obj:`AssignResult`: The assigned result. Note that shadowed_labels
  136. of shape (N, 2) is also added as an `assign_result` attribute.
  137. `shadowed_labels` is a tensor composed of N pairs of anchor_ind,
  138. class_label], where N is the number of anchors that lie in the
  139. outer region of a gt, anchor_ind is the shadowed anchor index
  140. and class_label is the shadowed class label.
  141. Example:
  142. >>> from mmengine.structures import InstanceData
  143. >>> self = CenterRegionAssigner(0.2, 0.2)
  144. >>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
  145. ... [10, 10, 20, 20]])
  146. >>> gt_instances = InstanceData()
  147. >>> gt_instances.bboxes = torch.Tensor([[0, 0, 10, 10]])
  148. >>> gt_instances.labels = torch.Tensor([0])
  149. >>> assign_result = self.assign(pred_instances, gt_instances)
  150. >>> expected_gt_inds = torch.LongTensor([1, 0])
  151. >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
  152. """
  153. # There are in total 5 steps in the pixel assignment
  154. # 1. Find core (the center region, say inner 0.2)
  155. # and shadow (the relatively ourter part, say inner 0.2-0.5)
  156. # regions of every gt.
  157. # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
  158. # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
  159. # the image.
  160. # 3.1. For overlapping objects, the prior bboxes in gt_core is
  161. # assigned with the object with smallest area
  162. # 4. Assign prior bboxes with class label according to its gt id.
  163. # 4.1. Assign -1 to prior bboxes lying in shadowed gts
  164. # 4.2. Assign positive prior boxes with the corresponding label
  165. # 5. Find pixels lying in the shadow of an object and assign them with
  166. # background label, but set the loss weight of its corresponding
  167. # gt to zero.
  168. # TODO not extract bboxes in assign.
  169. gt_bboxes = gt_instances.bboxes
  170. priors = pred_instances.priors
  171. gt_labels = gt_instances.labels
  172. assert priors.size(1) == 4, 'priors must have size of 4'
  173. # 1. Find core positive and shadow region of every gt
  174. gt_core = scale_boxes(gt_bboxes, self.pos_scale)
  175. gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
  176. # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
  177. prior_centers = (priors[:, 2:4] + priors[:, 0:2]) / 2
  178. # The center points lie within the gt boxes
  179. is_prior_in_gt = is_located_in(prior_centers, gt_bboxes)
  180. # Only calculate prior and gt_core IoF. This enables small prior bboxes
  181. # to match large gts
  182. prior_and_gt_core_overlaps = self.iou_calculator(
  183. priors, gt_core, mode='iof')
  184. # The center point of effective priors should be within the gt box
  185. is_prior_in_gt_core = is_prior_in_gt & (
  186. prior_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k)
  187. is_prior_in_gt_shadow = (
  188. self.iou_calculator(priors, gt_shadow, mode='iof') >
  189. self.min_pos_iof)
  190. # Rule out center effective positive pixels
  191. is_prior_in_gt_shadow &= (~is_prior_in_gt_core)
  192. num_gts, num_priors = gt_bboxes.size(0), priors.size(0)
  193. if num_gts == 0 or num_priors == 0:
  194. # If no gts exist, assign all pixels to negative
  195. assigned_gt_ids = \
  196. is_prior_in_gt_core.new_zeros((num_priors,),
  197. dtype=torch.long)
  198. pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
  199. else:
  200. # Step 3: assign a one-hot gt id to each pixel, and smaller objects
  201. # have high priority to assign the pixel.
  202. sort_idx = self.get_gt_priorities(gt_bboxes)
  203. assigned_gt_ids, pixels_in_gt_shadow = \
  204. self.assign_one_hot_gt_indices(is_prior_in_gt_core,
  205. is_prior_in_gt_shadow,
  206. gt_priority=sort_idx)
  207. if (gt_instances_ignore is not None
  208. and gt_instances_ignore.bboxes.numel() > 0):
  209. # No ground truth or boxes, return empty assignment
  210. gt_bboxes_ignore = gt_instances_ignore.bboxes
  211. gt_bboxes_ignore = scale_boxes(
  212. gt_bboxes_ignore, scale=self.ignore_gt_scale)
  213. is_prior_in_ignored_gts = is_located_in(prior_centers,
  214. gt_bboxes_ignore)
  215. is_prior_in_ignored_gts = is_prior_in_ignored_gts.any(dim=1)
  216. assigned_gt_ids[is_prior_in_ignored_gts] = -1
  217. # 4. Assign prior bboxes with class label according to its gt id.
  218. # Default assigned label is the background (-1)
  219. assigned_labels = assigned_gt_ids.new_full((num_priors, ), -1)
  220. pos_inds = torch.nonzero(assigned_gt_ids > 0, as_tuple=False).squeeze()
  221. if pos_inds.numel() > 0:
  222. assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds] -
  223. 1]
  224. # 5. Find pixels lying in the shadow of an object
  225. shadowed_pixel_labels = pixels_in_gt_shadow.clone()
  226. if pixels_in_gt_shadow.numel() > 0:
  227. pixel_idx, gt_idx =\
  228. pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
  229. assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
  230. 'Some pixels are dually assigned to ignore and gt!'
  231. shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
  232. override = (
  233. assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
  234. if self.foreground_dominate:
  235. # When a pixel is both positive and shadowed, set it as pos
  236. shadowed_pixel_labels = shadowed_pixel_labels[~override]
  237. else:
  238. # When a pixel is both pos and shadowed, set it as shadowed
  239. assigned_labels[pixel_idx[override]] = -1
  240. assigned_gt_ids[pixel_idx[override]] = 0
  241. assign_result = AssignResult(
  242. num_gts, assigned_gt_ids, None, labels=assigned_labels)
  243. # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
  244. assign_result.set_extra_property('shadowed_labels',
  245. shadowed_pixel_labels)
  246. return assign_result
  247. def assign_one_hot_gt_indices(
  248. self,
  249. is_prior_in_gt_core: Tensor,
  250. is_prior_in_gt_shadow: Tensor,
  251. gt_priority: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  252. """Assign only one gt index to each prior box.
  253. Gts with large gt_priority are more likely to be assigned.
  254. Args:
  255. is_prior_in_gt_core (Tensor): Bool tensor indicating the prior
  256. center is in the core area of a gt (e.g. 0-0.2).
  257. Shape: (num_prior, num_gt).
  258. is_prior_in_gt_shadow (Tensor): Bool tensor indicating the prior
  259. center is in the shadowed area of a gt (e.g. 0.2-0.5).
  260. Shape: (num_prior, num_gt).
  261. gt_priority (Tensor): Priorities of gts. The gt with a higher
  262. priority is more likely to be assigned to the bbox when the
  263. bbox match with multiple gts. Shape: (num_gt, ).
  264. Returns:
  265. tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
  266. - assigned_gt_inds: The assigned gt index of each prior bbox \
  267. (i.e. index from 1 to num_gts). Shape: (num_prior, ).
  268. - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
  269. shape (num_ignore, 2) with first column being the shadowed prior \
  270. bbox indices and the second column the shadowed gt \
  271. indices (1-based).
  272. """
  273. num_bboxes, num_gts = is_prior_in_gt_core.shape
  274. if gt_priority is None:
  275. gt_priority = torch.arange(
  276. num_gts, device=is_prior_in_gt_core.device)
  277. assert gt_priority.size(0) == num_gts
  278. # The bigger gt_priority, the more preferable to be assigned
  279. # The assigned inds are by default 0 (background)
  280. assigned_gt_inds = is_prior_in_gt_core.new_zeros((num_bboxes, ),
  281. dtype=torch.long)
  282. # Shadowed bboxes are assigned to be background. But the corresponding
  283. # label is ignored during loss calculation, which is done through
  284. # shadowed_gt_inds
  285. shadowed_gt_inds = torch.nonzero(is_prior_in_gt_shadow, as_tuple=False)
  286. if is_prior_in_gt_core.sum() == 0: # No gt match
  287. shadowed_gt_inds[:, 1] += 1 # 1-based. For consistency issue
  288. return assigned_gt_inds, shadowed_gt_inds
  289. # The priority of each prior box and gt pair. If one prior box is
  290. # matched bo multiple gts. Only the pair with the highest priority
  291. # is saved
  292. pair_priority = is_prior_in_gt_core.new_full((num_bboxes, num_gts),
  293. -1,
  294. dtype=torch.long)
  295. # Each bbox could match with multiple gts.
  296. # The following codes deal with this situation
  297. # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
  298. inds_of_match = torch.any(is_prior_in_gt_core, dim=1)
  299. # The matched gt index of each positive bbox. Length >= num_pos_anchor
  300. # , since one bbox could match multiple gts
  301. matched_bbox_gt_inds = torch.nonzero(
  302. is_prior_in_gt_core, as_tuple=False)[:, 1]
  303. # Assign priority to each bbox-gt pair.
  304. pair_priority[is_prior_in_gt_core] = gt_priority[matched_bbox_gt_inds]
  305. _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
  306. assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
  307. # Zero-out the assigned anchor box to filter the shadowed gt indices
  308. is_prior_in_gt_core[inds_of_match, argmax_priority] = 0
  309. # Concat the shadowed indices due to overlapping with that out side of
  310. # effective scale. shape: (total_num_ignore, 2)
  311. shadowed_gt_inds = torch.cat(
  312. (shadowed_gt_inds,
  313. torch.nonzero(is_prior_in_gt_core, as_tuple=False)),
  314. dim=0)
  315. # Change `is_prior_in_gt_core` back to keep arguments intact.
  316. is_prior_in_gt_core[inds_of_match, argmax_priority] = 1
  317. # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
  318. if shadowed_gt_inds.numel() > 0:
  319. shadowed_gt_inds[:, 1] += 1
  320. return assigned_gt_inds, shadowed_gt_inds