bucketing_bbox_coder.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Sequence, Tuple, Union
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.structures.bbox import (BaseBoxes, HorizontalBoxes, bbox_rescale,
  9. get_box_tensor)
  10. from .base_bbox_coder import BaseBBoxCoder
  11. @TASK_UTILS.register_module()
  12. class BucketingBBoxCoder(BaseBBoxCoder):
  13. """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
  14. Boundary Localization with Bucketing and Bucketing Guided Rescoring
  15. are implemented here.
  16. Please refer to https://arxiv.org/abs/1912.04260 for more details.
  17. Args:
  18. num_buckets (int): Number of buckets.
  19. scale_factor (int): Scale factor of proposals to generate buckets.
  20. offset_topk (int): Topk buckets are used to generate
  21. bucket fine regression targets. Defaults to 2.
  22. offset_upperbound (float): Offset upperbound to generate
  23. bucket fine regression targets.
  24. To avoid too large offset displacements. Defaults to 1.0.
  25. cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
  26. Defaults to True.
  27. clip_border (bool, optional): Whether clip the objects outside the
  28. border of the image. Defaults to True.
  29. """
  30. def __init__(self,
  31. num_buckets: int,
  32. scale_factor: int,
  33. offset_topk: int = 2,
  34. offset_upperbound: float = 1.0,
  35. cls_ignore_neighbor: bool = True,
  36. clip_border: bool = True,
  37. **kwargs) -> None:
  38. super().__init__(**kwargs)
  39. self.num_buckets = num_buckets
  40. self.scale_factor = scale_factor
  41. self.offset_topk = offset_topk
  42. self.offset_upperbound = offset_upperbound
  43. self.cls_ignore_neighbor = cls_ignore_neighbor
  44. self.clip_border = clip_border
  45. def encode(self, bboxes: Union[Tensor, BaseBoxes],
  46. gt_bboxes: Union[Tensor, BaseBoxes]) -> Tuple[Tensor]:
  47. """Get bucketing estimation and fine regression targets during
  48. training.
  49. Args:
  50. bboxes (torch.Tensor or :obj:`BaseBoxes`): source boxes,
  51. e.g., object proposals.
  52. gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): target of the
  53. transformation, e.g., ground truth boxes.
  54. Returns:
  55. encoded_bboxes(tuple[Tensor]): bucketing estimation
  56. and fine regression targets and weights
  57. """
  58. bboxes = get_box_tensor(bboxes)
  59. gt_bboxes = get_box_tensor(gt_bboxes)
  60. assert bboxes.size(0) == gt_bboxes.size(0)
  61. assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
  62. encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
  63. self.scale_factor, self.offset_topk,
  64. self.offset_upperbound,
  65. self.cls_ignore_neighbor)
  66. return encoded_bboxes
  67. def decode(
  68. self,
  69. bboxes: Union[Tensor, BaseBoxes],
  70. pred_bboxes: Tensor,
  71. max_shape: Optional[Tuple[int]] = None
  72. ) -> Tuple[Union[Tensor, BaseBoxes], Tensor]:
  73. """Apply transformation `pred_bboxes` to `boxes`.
  74. Args:
  75. boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes.
  76. pred_bboxes (torch.Tensor): Predictions for bucketing estimation
  77. and fine regression
  78. max_shape (tuple[int], optional): Maximum shape of boxes.
  79. Defaults to None.
  80. Returns:
  81. Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
  82. """
  83. bboxes = get_box_tensor(bboxes)
  84. assert len(pred_bboxes) == 2
  85. cls_preds, offset_preds = pred_bboxes
  86. assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
  87. 0) == bboxes.size(0)
  88. bboxes, loc_confidence = bucket2bbox(bboxes, cls_preds, offset_preds,
  89. self.num_buckets,
  90. self.scale_factor, max_shape,
  91. self.clip_border)
  92. if self.use_box_type:
  93. bboxes = HorizontalBoxes(bboxes, clone=False)
  94. return bboxes, loc_confidence
  95. def generat_buckets(proposals: Tensor,
  96. num_buckets: int,
  97. scale_factor: float = 1.0) -> Tuple[Tensor]:
  98. """Generate buckets w.r.t bucket number and scale factor of proposals.
  99. Args:
  100. proposals (Tensor): Shape (n, 4)
  101. num_buckets (int): Number of buckets.
  102. scale_factor (float): Scale factor to rescale proposals.
  103. Returns:
  104. tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
  105. t_buckets, d_buckets)
  106. - bucket_w: Width of buckets on x-axis. Shape (n, ).
  107. - bucket_h: Height of buckets on y-axis. Shape (n, ).
  108. - l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
  109. - r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
  110. - t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
  111. - d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
  112. """
  113. proposals = bbox_rescale(proposals, scale_factor)
  114. # number of buckets in each side
  115. side_num = int(np.ceil(num_buckets / 2.0))
  116. pw = proposals[..., 2] - proposals[..., 0]
  117. ph = proposals[..., 3] - proposals[..., 1]
  118. px1 = proposals[..., 0]
  119. py1 = proposals[..., 1]
  120. px2 = proposals[..., 2]
  121. py2 = proposals[..., 3]
  122. bucket_w = pw / num_buckets
  123. bucket_h = ph / num_buckets
  124. # left buckets
  125. l_buckets = px1[:, None] + (0.5 + torch.arange(
  126. 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
  127. # right buckets
  128. r_buckets = px2[:, None] - (0.5 + torch.arange(
  129. 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
  130. # top buckets
  131. t_buckets = py1[:, None] + (0.5 + torch.arange(
  132. 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
  133. # down buckets
  134. d_buckets = py2[:, None] - (0.5 + torch.arange(
  135. 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
  136. return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
  137. def bbox2bucket(proposals: Tensor,
  138. gt: Tensor,
  139. num_buckets: int,
  140. scale_factor: float,
  141. offset_topk: int = 2,
  142. offset_upperbound: float = 1.0,
  143. cls_ignore_neighbor: bool = True) -> Tuple[Tensor]:
  144. """Generate buckets estimation and fine regression targets.
  145. Args:
  146. proposals (Tensor): Shape (n, 4)
  147. gt (Tensor): Shape (n, 4)
  148. num_buckets (int): Number of buckets.
  149. scale_factor (float): Scale factor to rescale proposals.
  150. offset_topk (int): Topk buckets are used to generate
  151. bucket fine regression targets. Defaults to 2.
  152. offset_upperbound (float): Offset allowance to generate
  153. bucket fine regression targets.
  154. To avoid too large offset displacements. Defaults to 1.0.
  155. cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
  156. Defaults to True.
  157. Returns:
  158. tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
  159. - offsets: Fine regression targets. \
  160. Shape (n, num_buckets*2).
  161. - offsets_weights: Fine regression weights. \
  162. Shape (n, num_buckets*2).
  163. - bucket_labels: Bucketing estimation labels. \
  164. Shape (n, num_buckets*2).
  165. - cls_weights: Bucketing estimation weights. \
  166. Shape (n, num_buckets*2).
  167. """
  168. assert proposals.size() == gt.size()
  169. # generate buckets
  170. proposals = proposals.float()
  171. gt = gt.float()
  172. (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
  173. d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)
  174. gx1 = gt[..., 0]
  175. gy1 = gt[..., 1]
  176. gx2 = gt[..., 2]
  177. gy2 = gt[..., 3]
  178. # generate offset targets and weights
  179. # offsets from buckets to gts
  180. l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
  181. r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
  182. t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
  183. d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
  184. # select top-k nearest buckets
  185. l_topk, l_label = l_offsets.abs().topk(
  186. offset_topk, dim=1, largest=False, sorted=True)
  187. r_topk, r_label = r_offsets.abs().topk(
  188. offset_topk, dim=1, largest=False, sorted=True)
  189. t_topk, t_label = t_offsets.abs().topk(
  190. offset_topk, dim=1, largest=False, sorted=True)
  191. d_topk, d_label = d_offsets.abs().topk(
  192. offset_topk, dim=1, largest=False, sorted=True)
  193. offset_l_weights = l_offsets.new_zeros(l_offsets.size())
  194. offset_r_weights = r_offsets.new_zeros(r_offsets.size())
  195. offset_t_weights = t_offsets.new_zeros(t_offsets.size())
  196. offset_d_weights = d_offsets.new_zeros(d_offsets.size())
  197. inds = torch.arange(0, proposals.size(0)).to(proposals).long()
  198. # generate offset weights of top-k nearest buckets
  199. for k in range(offset_topk):
  200. if k >= 1:
  201. offset_l_weights[inds, l_label[:,
  202. k]] = (l_topk[:, k] <
  203. offset_upperbound).float()
  204. offset_r_weights[inds, r_label[:,
  205. k]] = (r_topk[:, k] <
  206. offset_upperbound).float()
  207. offset_t_weights[inds, t_label[:,
  208. k]] = (t_topk[:, k] <
  209. offset_upperbound).float()
  210. offset_d_weights[inds, d_label[:,
  211. k]] = (d_topk[:, k] <
  212. offset_upperbound).float()
  213. else:
  214. offset_l_weights[inds, l_label[:, k]] = 1.0
  215. offset_r_weights[inds, r_label[:, k]] = 1.0
  216. offset_t_weights[inds, t_label[:, k]] = 1.0
  217. offset_d_weights[inds, d_label[:, k]] = 1.0
  218. offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
  219. offsets_weights = torch.cat([
  220. offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
  221. ],
  222. dim=-1)
  223. # generate bucket labels and weight
  224. side_num = int(np.ceil(num_buckets / 2.0))
  225. labels = torch.stack(
  226. [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
  227. batch_size = labels.size(0)
  228. bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
  229. -1).float()
  230. bucket_cls_l_weights = (l_offsets.abs() < 1).float()
  231. bucket_cls_r_weights = (r_offsets.abs() < 1).float()
  232. bucket_cls_t_weights = (t_offsets.abs() < 1).float()
  233. bucket_cls_d_weights = (d_offsets.abs() < 1).float()
  234. bucket_cls_weights = torch.cat([
  235. bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
  236. bucket_cls_d_weights
  237. ],
  238. dim=-1)
  239. # ignore second nearest buckets for cls if necessary
  240. if cls_ignore_neighbor:
  241. bucket_cls_weights = (~((bucket_cls_weights == 1) &
  242. (bucket_labels == 0))).float()
  243. else:
  244. bucket_cls_weights[:] = 1.0
  245. return offsets, offsets_weights, bucket_labels, bucket_cls_weights
  246. def bucket2bbox(proposals: Tensor,
  247. cls_preds: Tensor,
  248. offset_preds: Tensor,
  249. num_buckets: int,
  250. scale_factor: float = 1.0,
  251. max_shape: Optional[Union[Sequence[int], Tensor,
  252. Sequence[Sequence[int]]]] = None,
  253. clip_border: bool = True) -> Tuple[Tensor]:
  254. """Apply bucketing estimation (cls preds) and fine regression (offset
  255. preds) to generate det bboxes.
  256. Args:
  257. proposals (Tensor): Boxes to be transformed. Shape (n, 4)
  258. cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
  259. offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
  260. num_buckets (int): Number of buckets.
  261. scale_factor (float): Scale factor to rescale proposals.
  262. max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
  263. clip_border (bool, optional): Whether clip the objects outside the
  264. border of the image. Defaults to True.
  265. Returns:
  266. tuple[Tensor]: (bboxes, loc_confidence).
  267. - bboxes: predicted bboxes. Shape (n, 4)
  268. - loc_confidence: localization confidence of predicted bboxes.
  269. Shape (n,).
  270. """
  271. side_num = int(np.ceil(num_buckets / 2.0))
  272. cls_preds = cls_preds.view(-1, side_num)
  273. offset_preds = offset_preds.view(-1, side_num)
  274. scores = F.softmax(cls_preds, dim=1)
  275. score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
  276. rescaled_proposals = bbox_rescale(proposals, scale_factor)
  277. pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
  278. ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
  279. px1 = rescaled_proposals[..., 0]
  280. py1 = rescaled_proposals[..., 1]
  281. px2 = rescaled_proposals[..., 2]
  282. py2 = rescaled_proposals[..., 3]
  283. bucket_w = pw / num_buckets
  284. bucket_h = ph / num_buckets
  285. score_inds_l = score_label[0::4, 0]
  286. score_inds_r = score_label[1::4, 0]
  287. score_inds_t = score_label[2::4, 0]
  288. score_inds_d = score_label[3::4, 0]
  289. l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
  290. r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
  291. t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
  292. d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
  293. offsets = offset_preds.view(-1, 4, side_num)
  294. inds = torch.arange(proposals.size(0)).to(proposals).long()
  295. l_offsets = offsets[:, 0, :][inds, score_inds_l]
  296. r_offsets = offsets[:, 1, :][inds, score_inds_r]
  297. t_offsets = offsets[:, 2, :][inds, score_inds_t]
  298. d_offsets = offsets[:, 3, :][inds, score_inds_d]
  299. x1 = l_buckets - l_offsets * bucket_w
  300. x2 = r_buckets - r_offsets * bucket_w
  301. y1 = t_buckets - t_offsets * bucket_h
  302. y2 = d_buckets - d_offsets * bucket_h
  303. if clip_border and max_shape is not None:
  304. x1 = x1.clamp(min=0, max=max_shape[1] - 1)
  305. y1 = y1.clamp(min=0, max=max_shape[0] - 1)
  306. x2 = x2.clamp(min=0, max=max_shape[1] - 1)
  307. y2 = y2.clamp(min=0, max=max_shape[0] - 1)
  308. bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
  309. dim=-1)
  310. # bucketing guided rescoring
  311. loc_confidence = score_topk[:, 0]
  312. top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
  313. loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
  314. loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
  315. return bboxes, loc_confidence