delta_xywh_bbox_coder.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. from typing import Optional, Sequence, Union
  4. import numpy as np
  5. import torch
  6. from torch import Tensor
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
  9. from .base_bbox_coder import BaseBBoxCoder
  10. @TASK_UTILS.register_module()
  11. class DeltaXYWHBBoxCoder(BaseBBoxCoder):
  12. """Delta XYWH BBox coder.
  13. Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
  14. this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
  15. decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
  16. Args:
  17. target_means (Sequence[float]): Denormalizing means of target for
  18. delta coordinates
  19. target_stds (Sequence[float]): Denormalizing standard deviation of
  20. target for delta coordinates
  21. clip_border (bool, optional): Whether clip the objects outside the
  22. border of the image. Defaults to True.
  23. add_ctr_clamp (bool): Whether to add center clamp, when added, the
  24. predicted box is clamped is its center is too far away from
  25. the original anchor's center. Only used by YOLOF. Default False.
  26. ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
  27. Default 32.
  28. """
  29. def __init__(self,
  30. target_means: Sequence[float] = (0., 0., 0., 0.),
  31. target_stds: Sequence[float] = (1., 1., 1., 1.),
  32. clip_border: bool = True,
  33. add_ctr_clamp: bool = False,
  34. ctr_clamp: int = 32,
  35. **kwargs) -> None:
  36. super().__init__(**kwargs)
  37. self.means = target_means
  38. self.stds = target_stds
  39. self.clip_border = clip_border
  40. self.add_ctr_clamp = add_ctr_clamp
  41. self.ctr_clamp = ctr_clamp
  42. def encode(self, bboxes: Union[Tensor, BaseBoxes],
  43. gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
  44. """Get box regression transformation deltas that can be used to
  45. transform the ``bboxes`` into the ``gt_bboxes``.
  46. Args:
  47. bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
  48. e.g., object proposals.
  49. gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
  50. transformation, e.g., ground-truth boxes.
  51. Returns:
  52. torch.Tensor: Box transformation deltas
  53. """
  54. bboxes = get_box_tensor(bboxes)
  55. gt_bboxes = get_box_tensor(gt_bboxes)
  56. assert bboxes.size(0) == gt_bboxes.size(0)
  57. assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
  58. encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
  59. return encoded_bboxes
  60. def decode(
  61. self,
  62. bboxes: Union[Tensor, BaseBoxes],
  63. pred_bboxes: Tensor,
  64. max_shape: Optional[Union[Sequence[int], Tensor,
  65. Sequence[Sequence[int]]]] = None,
  66. wh_ratio_clip: Optional[float] = 16 / 1000
  67. ) -> Union[Tensor, BaseBoxes]:
  68. """Apply transformation `pred_bboxes` to `boxes`.
  69. Args:
  70. bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
  71. (B, N, 4) or (N, 4)
  72. pred_bboxes (Tensor): Encoded offsets with respect to each roi.
  73. Has shape (B, N, num_classes * 4) or (B, N, 4) or
  74. (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
  75. when rois is a grid of anchors.Offset encoding follows [1]_.
  76. max_shape (Sequence[int] or torch.Tensor or Sequence[
  77. Sequence[int]],optional): Maximum bounds for boxes, specifies
  78. (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
  79. the max_shape should be a Sequence[Sequence[int]]
  80. and the length of max_shape should also be B.
  81. wh_ratio_clip (float, optional): The allowed ratio between
  82. width and height.
  83. Returns:
  84. Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
  85. """
  86. bboxes = get_box_tensor(bboxes)
  87. assert pred_bboxes.size(0) == bboxes.size(0)
  88. if pred_bboxes.ndim == 3:
  89. assert pred_bboxes.size(1) == bboxes.size(1)
  90. if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
  91. # single image decode
  92. decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
  93. self.stds, max_shape, wh_ratio_clip,
  94. self.clip_border, self.add_ctr_clamp,
  95. self.ctr_clamp)
  96. else:
  97. if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
  98. warnings.warn(
  99. 'DeprecationWarning: onnx_delta2bbox is deprecated '
  100. 'in the case of batch decoding and non-ONNX, '
  101. 'please use “delta2bbox” instead. In order to improve '
  102. 'the decoding speed, the batch function will no '
  103. 'longer be supported. ')
  104. decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means,
  105. self.stds, max_shape,
  106. wh_ratio_clip, self.clip_border,
  107. self.add_ctr_clamp,
  108. self.ctr_clamp)
  109. if self.use_box_type:
  110. assert decoded_bboxes.size(-1) == 4, \
  111. ('Cannot warp decoded boxes with box type when decoded boxes'
  112. 'have shape of (N, num_classes * 4)')
  113. decoded_bboxes = HorizontalBoxes(decoded_bboxes)
  114. return decoded_bboxes
  115. def bbox2delta(
  116. proposals: Tensor,
  117. gt: Tensor,
  118. means: Sequence[float] = (0., 0., 0., 0.),
  119. stds: Sequence[float] = (1., 1., 1., 1.)
  120. ) -> Tensor:
  121. """Compute deltas of proposals w.r.t. gt.
  122. We usually compute the deltas of x, y, w, h of proposals w.r.t ground
  123. truth bboxes to get regression target.
  124. This is the inverse function of :func:`delta2bbox`.
  125. Args:
  126. proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
  127. gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
  128. means (Sequence[float]): Denormalizing means for delta coordinates
  129. stds (Sequence[float]): Denormalizing standard deviation for delta
  130. coordinates
  131. Returns:
  132. Tensor: deltas with shape (N, 4), where columns represent dx, dy,
  133. dw, dh.
  134. """
  135. assert proposals.size() == gt.size()
  136. proposals = proposals.float()
  137. gt = gt.float()
  138. px = (proposals[..., 0] + proposals[..., 2]) * 0.5
  139. py = (proposals[..., 1] + proposals[..., 3]) * 0.5
  140. pw = proposals[..., 2] - proposals[..., 0]
  141. ph = proposals[..., 3] - proposals[..., 1]
  142. gx = (gt[..., 0] + gt[..., 2]) * 0.5
  143. gy = (gt[..., 1] + gt[..., 3]) * 0.5
  144. gw = gt[..., 2] - gt[..., 0]
  145. gh = gt[..., 3] - gt[..., 1]
  146. dx = (gx - px) / pw
  147. dy = (gy - py) / ph
  148. dw = torch.log(gw / pw)
  149. dh = torch.log(gh / ph)
  150. deltas = torch.stack([dx, dy, dw, dh], dim=-1)
  151. means = deltas.new_tensor(means).unsqueeze(0)
  152. stds = deltas.new_tensor(stds).unsqueeze(0)
  153. deltas = deltas.sub_(means).div_(stds)
  154. return deltas
  155. def delta2bbox(rois: Tensor,
  156. deltas: Tensor,
  157. means: Sequence[float] = (0., 0., 0., 0.),
  158. stds: Sequence[float] = (1., 1., 1., 1.),
  159. max_shape: Optional[Union[Sequence[int], Tensor,
  160. Sequence[Sequence[int]]]] = None,
  161. wh_ratio_clip: float = 16 / 1000,
  162. clip_border: bool = True,
  163. add_ctr_clamp: bool = False,
  164. ctr_clamp: int = 32) -> Tensor:
  165. """Apply deltas to shift/scale base boxes.
  166. Typically the rois are anchor or proposed bounding boxes and the deltas are
  167. network outputs used to shift/scale those boxes.
  168. This is the inverse function of :func:`bbox2delta`.
  169. Args:
  170. rois (Tensor): Boxes to be transformed. Has shape (N, 4).
  171. deltas (Tensor): Encoded offsets relative to each roi.
  172. Has shape (N, num_classes * 4) or (N, 4). Note
  173. N = num_base_anchors * W * H, when rois is a grid of
  174. anchors. Offset encoding follows [1]_.
  175. means (Sequence[float]): Denormalizing means for delta coordinates.
  176. Default (0., 0., 0., 0.).
  177. stds (Sequence[float]): Denormalizing standard deviation for delta
  178. coordinates. Default (1., 1., 1., 1.).
  179. max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
  180. (H, W). Default None.
  181. wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
  182. 16 / 1000.
  183. clip_border (bool, optional): Whether clip the objects outside the
  184. border of the image. Default True.
  185. add_ctr_clamp (bool): Whether to add center clamp. When set to True,
  186. the center of the prediction bounding box will be clamped to
  187. avoid being too far away from the center of the anchor.
  188. Only used by YOLOF. Default False.
  189. ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
  190. Default 32.
  191. Returns:
  192. Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
  193. represent tl_x, tl_y, br_x, br_y.
  194. References:
  195. .. [1] https://arxiv.org/abs/1311.2524
  196. Example:
  197. >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
  198. >>> [ 0., 0., 1., 1.],
  199. >>> [ 0., 0., 1., 1.],
  200. >>> [ 5., 5., 5., 5.]])
  201. >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
  202. >>> [ 1., 1., 1., 1.],
  203. >>> [ 0., 0., 2., -1.],
  204. >>> [ 0.7, -1.9, -0.5, 0.3]])
  205. >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
  206. tensor([[0.0000, 0.0000, 1.0000, 1.0000],
  207. [0.1409, 0.1409, 2.8591, 2.8591],
  208. [0.0000, 0.3161, 4.1945, 0.6839],
  209. [5.0000, 5.0000, 5.0000, 5.0000]])
  210. """
  211. num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
  212. if num_bboxes == 0:
  213. return deltas
  214. deltas = deltas.reshape(-1, 4)
  215. means = deltas.new_tensor(means).view(1, -1)
  216. stds = deltas.new_tensor(stds).view(1, -1)
  217. denorm_deltas = deltas * stds + means
  218. dxy = denorm_deltas[:, :2]
  219. dwh = denorm_deltas[:, 2:]
  220. # Compute width/height of each roi
  221. rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
  222. pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
  223. pwh = (rois_[:, 2:] - rois_[:, :2])
  224. dxy_wh = pwh * dxy
  225. max_ratio = np.abs(np.log(wh_ratio_clip))
  226. if add_ctr_clamp:
  227. dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
  228. dwh = torch.clamp(dwh, max=max_ratio)
  229. else:
  230. dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
  231. gxy = pxy + dxy_wh
  232. gwh = pwh * dwh.exp()
  233. x1y1 = gxy - (gwh * 0.5)
  234. x2y2 = gxy + (gwh * 0.5)
  235. bboxes = torch.cat([x1y1, x2y2], dim=-1)
  236. if clip_border and max_shape is not None:
  237. bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
  238. bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
  239. bboxes = bboxes.reshape(num_bboxes, -1)
  240. return bboxes
def onnx_delta2bbox(rois: Tensor,
                    deltas: Tensor,
                    means: Sequence[float] = (0., 0., 0., 0.),
                    stds: Sequence[float] = (1., 1., 1., 1.),
                    max_shape: Optional[Union[Sequence[int], Tensor,
                                              Sequence[Sequence[int]]]] = None,
                    wh_ratio_clip: float = 16 / 1000,
                    clip_border: Optional[bool] = True,
                    add_ctr_clamp: bool = False,
                    ctr_clamp: int = 32) -> Tensor:
    """Apply deltas to shift/scale base boxes (batch / ONNX-export variant).

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes. Unlike
    :func:`delta2bbox`, this version works on arbitrary leading dimensions
    via ``...`` indexing and clips with ``torch.where``. During ONNX export
    it delegates clipping to ``dynamic_clip_for_onnx``.

    This is the inverse function of :func:`bbox2delta`.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
        deltas (Tensor): Encoded offsets with respect to each roi.
            Has shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
            when rois is a grid of anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1.).
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]],optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If rois shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B. Default None.
        wh_ratio_clip (float): Maximum aspect ratio for boxes.
            Default 16 / 1000.
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Default True.
        add_ctr_clamp (bool): Whether to add center clamp. When added, the
            predicted box is clamped if its center is too far away from
            the original anchor's center. Only used by YOLOF. Default False.
        ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
            Default 32.

    Returns:
        Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
        (N, num_classes * 4) or (N, 4), where 4 represent
        tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524

    Example:
        >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
        ...                      [ 0., 0., 1., 1.],
        ...                      [ 0., 0., 1., 1.],
        ...                      [ 5., 5., 5., 5.]])
        >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
        ...                        [ 1., 1., 1., 1.],
        ...                        [ 0., 0., 2., -1.],
        ...                        [ 0.7, -1.9, -0.5, 0.3]])
        >>> onnx_delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.0000, 0.3161, 4.1945, 0.6839],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    # Tile the 4 per-coordinate means/stds once per class so they line up
    # with the last axis of ``deltas`` (num_classes * 4 entries).
    means = deltas.new_tensor(means).view(1,
                                          -1).repeat(1,
                                                     deltas.size(-1) // 4)
    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
    denorm_deltas = deltas * stds + means
    # Strided slices select the dx/dy/dw/dh entry of every class; each has
    # the leading shape of ``deltas`` with num_classes on the last axis.
    dx = denorm_deltas[..., 0::4]
    dy = denorm_deltas[..., 1::4]
    dw = denorm_deltas[..., 2::4]
    dh = denorm_deltas[..., 3::4]

    x1, y1 = rois[..., 0], rois[..., 1]
    x2, y2 = rois[..., 2], rois[..., 3]
    # Compute center of each roi; the trailing axis + expand_as broadcasts
    # each roi over its classes (and over the batch axis when needed).
    px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
    py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
    # Compute width/height of each roi
    pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
    ph = (y2 - y1).unsqueeze(-1).expand_as(dh)

    # Center shift in pixels.
    dx_width = pw * dx
    dy_height = ph * dy

    # dw/dh bound so that exp(dw) stays within [wh_ratio_clip, 1/wh_ratio_clip].
    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        # YOLOF mode: bound the pixel shift; sizes are only clamped above.
        dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
        dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
        dw = torch.clamp(dw, max=max_ratio)
        dh = torch.clamp(dh, max=max_ratio)
    else:
        dw = dw.clamp(min=-max_ratio, max=max_ratio)
        dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + dx_width
    gy = py + dy_height
    # Convert center-xy/width/height to top-left, bottom-right
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5

    # Stacking on a new last axis interleaves (x1, y1, x2, y2) per class;
    # view() then restores the exact shape of ``deltas``.
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())

    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            # NOTE(review): `mmdet.core` is the MMDetection 2.x package
            # layout, while this file imports from `mmdet.registry` /
            # `mmdet.structures` (3.x). Confirm `mmdet.core.export` exists in
            # the installed version; otherwise this ONNX-only branch raises
            # ImportError when reached.
            from mmdet.core.export import dynamic_clip_for_onnx
            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
            bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        # Keep only (H, W), dropping an optional trailing channel entry.
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            # Per-image bounds: one (H, W) row per batch element.
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = x1.new_tensor(0)
        # Repeat (H, W) once per coordinate pair, giving (H, W, H, W, ...).
        # The flip turns that into (W, H, W, H, ...), matching the
        # (x, y, x, y, ...) column order. unsqueeze(-2) adds a box axis for
        # broadcasting.
        max_xy = torch.cat(
            [max_shape] * (deltas.size(-1) // 2),
            dim=-1).flip(-1).unsqueeze(-2)
        # Clamp to [0, max_xy] with torch.where, which takes tensor bounds.
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes