# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import torch
from mmdet.models.task_modules.coders.delta_xywh_bbox_coder import \
    DeltaXYWHBBoxCoder
from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import HorizontalBoxes, get_box_tensor


@TASK_UTILS.register_module()
class YXYXDeltaXYWHBBoxCoder(DeltaXYWHBBoxCoder):
    """Delta XYWH bbox coder that consumes anchors in (y1, x1, y2, x2) order.

    Deltas are encoded as (dy, dx, dh, dw) against YXYX anchors, while
    ground-truth inputs and decoded outputs use the standard mmdet
    (x1, y1, x2, y2) order.
    """

    def encode(self, bboxes, gt_bboxes):
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Note that ``bboxes`` are expected in (y1, x1, y2, x2) order while
        ``gt_bboxes`` follow the standard (x1, y1, x2, y2) order; the
        returned deltas are ordered (dy, dx, dh, dw).

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g., object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
                transformation, e.g., ground-truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas.
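
        Example:
            A minimal illustrative doctest, assuming the default
            ``target_means``/``target_stds`` of the parent coder; identical
            (and symmetric) boxes encode to zero deltas:

            >>> coder = YXYXDeltaXYWHBBoxCoder()
            >>> boxes = torch.Tensor([[0., 0., 10., 10.]])
            >>> coder.encode(boxes, boxes)
            tensor([[0., 0., 0., 0.]])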
  21. """
  22. bboxes = get_box_tensor(bboxes)
  23. gt_bboxes = get_box_tensor(gt_bboxes)
  24. assert bboxes.size(0) == gt_bboxes.size(0)
  25. assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
  26. encoded_bboxes = YXbbox2delta(bboxes, gt_bboxes, self.means, self.stds)
  27. return encoded_bboxes
    def decode(self,
               bboxes,
               pred_bboxes,
               max_shape=None,
               wh_ratio_clip=16 / 1000):
        """Apply transformation ``pred_bboxes`` to ``bboxes``.

        ``bboxes`` are expected in (y1, x1, y2, x2) order; the decoded boxes
        are returned in the standard (x1, y1, x2, y2) order.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
                (B, N, 4) or (N, 4).
            pred_bboxes (Tensor): Encoded offsets with respect to each roi.
                Has shape (B, N, num_classes * 4) or (B, N, 4) or
                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
                when rois is a grid of anchors. Offset encoding follows [1]_.
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). If bboxes shape is (B, N, 4),
                then the max_shape should be a Sequence[Sequence[int]]
                and the length of max_shape should also be B.
            wh_ratio_clip (float, optional): The allowed ratio between
                width and height.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
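
        Example:
            An illustrative doctest, assuming the default coder settings;
            zero deltas decode back to the input box, re-ordered from
            (y1, x1, y2, x2) to (x1, y1, x2, y2):

            >>> coder = YXYXDeltaXYWHBBoxCoder()
            >>> boxes = torch.Tensor([[0., 0., 10., 10.]])
            >>> coder.decode(boxes, torch.zeros(1, 4))
            tensor([[ 0.,  0., 10., 10.]])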
  50. """
  51. bboxes = get_box_tensor(bboxes)
  52. assert pred_bboxes.size(0) == bboxes.size(0)
  53. if pred_bboxes.ndim == 3:
  54. assert pred_bboxes.size(1) == bboxes.size(1)
  55. if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
  56. # single image decode
  57. decoded_bboxes = YXdelta2bbox(bboxes, pred_bboxes, self.means,
  58. self.stds, max_shape, wh_ratio_clip,
  59. self.clip_border, self.add_ctr_clamp,
  60. self.ctr_clamp)
  61. else:
  62. if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
  63. warnings.warn(
  64. 'DeprecationWarning: onnx_delta2bbox is deprecated '
  65. 'in the case of batch decoding and non-ONNX, '
  66. 'please use “delta2bbox” instead. In order to improve '
  67. 'the decoding speed, the batch function will no '
  68. 'longer be supported. ')
  69. decoded_bboxes = YXonnx_delta2bbox(bboxes, pred_bboxes, self.means,
  70. self.stds, max_shape,
  71. wh_ratio_clip, self.clip_border,
  72. self.add_ctr_clamp,
  73. self.ctr_clamp)
  74. if self.use_box_type:
  75. assert decoded_bboxes.size(-1) == 4, \
  76. ('Cannot warp decoded boxes with box type when decoded boxes'
  77. 'have shape of (N, num_classes * 4)')
  78. decoded_bboxes = HorizontalBoxes(decoded_bboxes)
  79. return decoded_bboxes
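

# Usage sketch (illustrative, not part of the original module): the coder is
# typically built through the mmdet registry from a config dict. The ``type``
# key matches the class registered above; the kwargs follow
# DeltaXYWHBBoxCoder, and the stds values below are example settings.
#
#   coder = TASK_UTILS.build(
#       dict(type='YXYXDeltaXYWHBBoxCoder',
#            target_means=(0., 0., 0., 0.),
#            target_stds=(0.1, 0.1, 0.2, 0.2)))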
def YXdelta2bbox(rois,
                 deltas,
                 means=(0., 0., 0., 0.),
                 stds=(1., 1., 1., 1.),
                 max_shape=None,
                 wh_ratio_clip=16 / 1000,
                 clip_border=True,
                 add_ctr_clamp=False,
                 ctr_clamp=32):
    """Apply deltas to shift/scale base boxes.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes. This is the inverse
    function of :func:`YXbbox2delta`. The rois are expected in
    (y1, x1, y2, x2) order, the deltas in (dy, dx, dh, dw) order, and the
    returned boxes are in the standard (x1, y1, x2, y2) order.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4).
        deltas (Tensor): Encoded offsets relative to each roi.
            Has shape (N, num_classes * 4) or (N, 4). Note
            N = num_base_anchors * W * H, when rois is a grid of
            anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1.).
        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
            (H, W). Default None.
        wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
            16 / 1000.
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Default True.
        add_ctr_clamp (bool): Whether to add center clamp. When set to True,
            the center of the prediction bounding box will be clamped to
            avoid being too far away from the center of the anchor.
            Only used by YOLOF. Default False.
        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
            YOLOF. Default 32.

    Returns:
        Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
            represent tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524
    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> YXdelta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.3161, 0.0000, 0.6839, 4.1945],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
    if num_bboxes == 0:
        return deltas

    deltas = deltas.reshape(-1, 4)

    means = deltas.new_tensor(means).view(1, -1)
    stds = deltas.new_tensor(stds).view(1, -1)
    denorm_deltas = deltas * stds + means

    dyx = denorm_deltas[:, :2]
    dhw = denorm_deltas[:, 2:]

    # Compute center (y, x) and size (h, w) of each roi
    rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
    pyx = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
    phw = (rois_[:, 2:] - rois_[:, :2])

    dyx_hw = phw * dyx

    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        dyx_hw = torch.clamp(dyx_hw, max=ctr_clamp, min=-ctr_clamp)
        dhw = torch.clamp(dhw, max=max_ratio)
    else:
        dhw = dhw.clamp(min=-max_ratio, max=max_ratio)

    gyx = pyx + dyx_hw
    ghw = phw * dhw.exp()
    y1x1 = gyx - (ghw * 0.5)
    y2x2 = gyx + (ghw * 0.5)
    # Re-order the (y, x) pairs into standard (x1, y1, x2, y2) boxes
    ymin, xmin = y1x1[:, 0].reshape(-1, 1), y1x1[:, 1].reshape(-1, 1)
    ymax, xmax = y2x2[:, 0].reshape(-1, 1), y2x2[:, 1].reshape(-1, 1)
    bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
    if clip_border and max_shape is not None:
        bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
        bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
    bboxes = bboxes.reshape(num_bboxes, -1)
    return bboxes
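

# Helper sketch (illustrative, not part of the original module): anchors must
# be supplied to this coder in (y1, x1, y2, x2) order, so callers holding
# standard (x1, y1, x2, y2) boxes can swap the coordinate pairs first. The
# name `_xyxy_to_yxyx` is hypothetical.
def _xyxy_to_yxyx(boxes):
    """Swap coordinate pairs: (x1, y1, x2, y2) -> (y1, x1, y2, x2)."""
    return boxes[..., [1, 0, 3, 2]]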
def YXbbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
    """Compute deltas of proposals w.r.t. gt.

    We usually compute the deltas of y, x, h, w of proposals w.r.t. ground
    truth bboxes to get regression targets. This is the inverse function of
    :func:`YXdelta2bbox`. The proposals are expected in (y1, x1, y2, x2)
    order and the gt boxes in the standard (x1, y1, x2, y2) order.

    Args:
        proposals (Tensor): Boxes to be transformed, shape (N, ..., 4).
        gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4).
        means (Sequence[float]): Denormalizing means for delta coordinates.
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates.

    Returns:
        Tensor: deltas with shape (N, 4), where columns represent dy, dx,
            dh, dw.
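
    Example:
        An illustrative doctest: a gt box shifted right by half the anchor
        width yields dx = 0.5 and zeros for the remaining deltas.

        >>> proposals = torch.Tensor([[0., 0., 10., 10.]])  # (y1, x1, y2, x2)
        >>> gt = torch.Tensor([[5., 0., 15., 10.]])         # (x1, y1, x2, y2)
        >>> YXbbox2delta(proposals, gt)
        tensor([[0.0000, 0.5000, 0.0000, 0.0000]])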
  181. """
  182. assert proposals.size() == gt.size()
  183. proposals = proposals.float()
  184. gt = gt.float()
  185. py = (proposals[..., 0] + proposals[..., 2]) * 0.5
  186. px = (proposals[..., 1] + proposals[..., 3]) * 0.5
  187. ph = proposals[..., 2] - proposals[..., 0]
  188. pw = proposals[..., 3] - proposals[..., 1]
  189. gx = (gt[..., 0] + gt[..., 2]) * 0.5
  190. gy = (gt[..., 1] + gt[..., 3]) * 0.5
  191. gw = gt[..., 2] - gt[..., 0]
  192. gh = gt[..., 3] - gt[..., 1]
  193. dx = (gx - px) / pw
  194. dy = (gy - py) / ph
  195. dw = torch.log(gw / pw)
  196. dh = torch.log(gh / ph)
  197. deltas = torch.stack([dy, dx, dh, dw], dim=-1)
  198. means = deltas.new_tensor(means).unsqueeze(0)
  199. stds = deltas.new_tensor(stds).unsqueeze(0)
  200. deltas = deltas.sub_(means).div_(stds)
  201. return deltas
def YXonnx_delta2bbox(rois,
                      deltas,
                      means=(0., 0., 0., 0.),
                      stds=(1., 1., 1., 1.),
                      max_shape=None,
                      wh_ratio_clip=16 / 1000,
                      clip_border=True,
                      add_ctr_clamp=False,
                      ctr_clamp=32):
    """Apply deltas to shift/scale base boxes (ONNX-friendly, batched).

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes. This is the inverse
    function of :func:`YXbbox2delta`. As in :func:`YXdelta2bbox`, the rois
    are expected in (y1, x1, y2, x2) order, the deltas in (dy, dx, dh, dw)
    order, and the returned boxes are in the standard (x1, y1, x2, y2) order.

    Args:
        rois (Tensor): Boxes to be transformed. Has shape (N, 4) or
            (B, N, 4).
        deltas (Tensor): Encoded offsets with respect to each roi.
            Has shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
            when rois is a grid of anchors. Offset encoding follows [1]_.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1.).
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If rois shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B. Default None.
        wh_ratio_clip (float): Maximum aspect ratio for boxes.
            Default 16 / 1000.
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Default True.
        add_ctr_clamp (bool): Whether to add center clamp. When set to True,
            the predicted box is clamped if its center is too far away from
            the original anchor's center. Only used by YOLOF. Default False.
        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
            YOLOF. Default 32.

    Returns:
        Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
            (N, num_classes * 4) or (N, 4), where 4 represent
            tl_x, tl_y, br_x, br_y.

    References:
        .. [1] https://arxiv.org/abs/1311.2524
    Example:
        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 0.,  0.,  1.,  1.],
        >>>                      [ 5.,  5.,  5.,  5.]])
        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
        >>>                        [  1.,   1.,   1.,   1.],
        >>>                        [  0.,   0.,   2.,  -1.],
        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
        >>> YXonnx_delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                [0.1409, 0.1409, 2.8591, 2.8591],
                [0.3161, 0.0000, 0.6839, 4.1945],
                [5.0000, 5.0000, 5.0000, 5.0000]])
    """
    means = deltas.new_tensor(means).view(1,
                                          -1).repeat(1,
                                                     deltas.size(-1) // 4)
    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
    denorm_deltas = deltas * stds + means
    dy = denorm_deltas[..., 0::4]
    dx = denorm_deltas[..., 1::4]
    dh = denorm_deltas[..., 2::4]
    dw = denorm_deltas[..., 3::4]

    # rois are (y1, x1, y2, x2)
    y1, x1 = rois[..., 0], rois[..., 1]
    y2, x2 = rois[..., 2], rois[..., 3]
    # Compute center of each roi
    px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
    py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
    # Compute width/height of each roi
    pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
    ph = (y2 - y1).unsqueeze(-1).expand_as(dh)

    dx_width = pw * dx
    dy_height = ph * dy

    max_ratio = np.abs(np.log(wh_ratio_clip))
    if add_ctr_clamp:
        dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
        dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
        dw = torch.clamp(dw, max=max_ratio)
        dh = torch.clamp(dh, max=max_ratio)
    else:
        dw = dw.clamp(min=-max_ratio, max=max_ratio)
        dh = dh.clamp(min=-max_ratio, max=max_ratio)

    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + dx_width
    gy = py + dy_height
    # Convert center-xy/width/height to top-left, bottom-right
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5

    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            # NOTE: `mmdet.core` only exists in mmdet 2.x; this ONNX-export
            # branch is unreachable under mmdet 3.x.
            from mmdet.core.export import dynamic_clip_for_onnx
            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
            bboxes = torch.stack([x1, y1, x2, y2],
                                 dim=-1).view(deltas.size())
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = x1.new_tensor(0)
        max_xy = torch.cat(
            [max_shape] * (deltas.size(-1) // 2),
            dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
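

# Minimal round-trip smoke test (an illustrative addition, not in the
# original file): encoding an XYXY gt box against a YXYX anchor and decoding
# the resulting deltas should recover the gt box.
if __name__ == '__main__':
    anchors_yxyx = torch.Tensor([[0., 0., 10., 10.]])  # (y1, x1, y2, x2)
    gt_xyxy = torch.Tensor([[5., 0., 15., 10.]])       # (x1, y1, x2, y2)
    deltas = YXbbox2delta(anchors_yxyx, gt_xyxy)       # -> (dy, dx, dh, dw)
    decoded = YXdelta2bbox(anchors_yxyx, deltas)
    assert torch.allclose(decoded, gt_xyxy), (decoded, gt_xyxy)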