# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import torch
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor

from .base_bbox_coder import BaseBBoxCoder


@TASK_UTILS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """TBLR BBox coder.

    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
    right) format and decodes them back to the original format.

    Args:
        normalizer (list | float): Normalization factor that the encoded
            coordinates are divided by. If it is a list, it must have
            length 4, giving the normalization factor for each tblr dim.
            Otherwise it is a single float factor applied to all dims.
            Default: 4.0.
        clip_border (bool, optional): Whether to clip boxes that lie outside
            the border of the image. Defaults to True.
    """

    def __init__(self,
                 normalizer: Union[Sequence[float], float] = 4.0,
                 clip_border: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.normalizer = normalizer
        self.clip_border = clip_border

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes`` in the (top, bottom,
        left, right) order.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g., object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
                transformation, e.g., ground-truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        bboxes = get_box_tensor(bboxes)
        gt_bboxes = get_box_tensor(gt_bboxes)
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bboxes2tblr(
            bboxes, gt_bboxes, normalizer=self.normalizer)
        return encoded_bboxes

    def decode(
        self,
        bboxes: Union[Tensor, BaseBoxes],
        pred_bboxes: Tensor,
        max_shape: Optional[Union[Sequence[int], Tensor,
                                  Sequence[Sequence[int]]]] = None
    ) -> Union[Tensor, BaseBoxes]:
        """Apply transformation ``pred_bboxes`` to ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes with
                shape (B, N, 4) or (N, 4).
            pred_bboxes (torch.Tensor): Encoded boxes with shape
                (B, N, 4) or (N, 4).
            max_shape (Sequence[int] or torch.Tensor or
                Sequence[Sequence[int]], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). If bboxes shape is (B, N, 4),
                then max_shape should be a Sequence[Sequence[int]] of
                length B.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
        """
        bboxes = get_box_tensor(bboxes)
        decoded_bboxes = tblr2bboxes(
            bboxes,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape,
            clip_border=self.clip_border)

        if self.use_box_type:
            decoded_bboxes = HorizontalBoxes(decoded_bboxes)
        return decoded_bboxes
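

# A minimal usage sketch (illustrative, not part of the original file). The
# prior/gt values are made up; encode/decode delegate to bboxes2tblr and
# tblr2bboxes defined below:
#
#     coder = TBLRBBoxCoder(normalizer=4.0)
#     priors = torch.tensor([[10., 10., 50., 50.]])
#     gts = torch.tensor([[12., 8., 48., 52.]])
#     deltas = coder.encode(priors, gts)    # (1, 4) deltas in tblr order
#     boxes = coder.decode(priors, deltas)  # recovers gts up to fp error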


def bboxes2tblr(priors: Tensor,
                gts: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True) -> Tensor:
    """Encode ground truth boxes to tblr coordinates.

    It first converts the gt coordinates to tblr format,
    (top, bottom, left, right), relative to the prior box centers.
    The tblr coordinates may be normalized by the side lengths of the prior
    bboxes if `normalize_by_wh` is True, and are then normalized by the
    `normalizer` factor.

    Args:
        priors (Tensor): Prior boxes in point form,
            shape (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form,
            shape (num_proposals, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. If it is a list, it has to have length 4.
            Default: 4.0.
        normalize_by_wh (bool): Whether to normalize the tblr coordinates by
            the side lengths (wh) of the prior bboxes.

    Returns:
        encoded boxes (Tensor), shape (num_proposals, 4)
    """
    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)

    # Distances between the gt box sides and the prior centers
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = prior_centers[:, 1].unsqueeze(1) - ymin
    bottom = ymax - prior_centers[:, 1].unsqueeze(1)
    left = prior_centers[:, 0].unsqueeze(1) - xmin
    right = xmax - prior_centers[:, 0].unsqueeze(1)
    loc = torch.cat((top, bottom, left, right), dim=1)
    if normalize_by_wh:
        # Normalize tblr by anchor width and height
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h  # tb is normalized by h
        loc[:, 2:] /= w  # lr is normalized by w
    # Normalize tblr by the given normalization factor
    return loc / normalizer
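

# Worked example (illustrative): a prior (10, 10, 50, 50) has center (30, 30)
# and w = h = 40. For a gt (12, 8, 48, 52):
#     top = 30 - 8 = 22, bottom = 52 - 30 = 22,
#     left = 30 - 12 = 18, right = 48 - 30 = 18.
# With normalize_by_wh=True and normalizer=4.0, the encoding is
#     (22 / 40 / 4, 22 / 40 / 4, 18 / 40 / 4, 18 / 40 / 4)
#     = (0.1375, 0.1375, 0.1125, 0.1125).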


def tblr2bboxes(priors: Tensor,
                tblr: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True,
                max_shape: Optional[Union[Sequence[int], Tensor,
                                          Sequence[Sequence[int]]]] = None,
                clip_border: bool = True) -> Tensor:
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize the tblr coordinates by
    multiplying them with `normalizer`; 2) De-normalize the tblr coordinates
    by the prior bbox width and height if `normalize_by_wh` is `True`;
    3) Convert the tblr (top, bottom, left, right) values, which are relative
    to the prior centers, back to (xmin, ymin, xmax, ymax) coordinates.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1),
            shape (N, 4) or (B, N, 4).
        tblr (Tensor): Coords of network output in tblr form,
            shape (N, 4) or (B, N, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. As a list, it gives the normalization factor for
            each tblr dim. As a float, it is a unified normalization factor
            for all dims. Default: 4.0.
        normalize_by_wh (bool): Whether the tblr coordinates have been
            normalized by the side lengths (wh) of the prior bboxes.
        max_shape (Sequence[int] or torch.Tensor or
            Sequence[Sequence[int]], optional): Maximum bounds for boxes,
            specifies (H, W, C) or (H, W). If priors shape is (B, N, 4),
            then max_shape should be a Sequence[Sequence[int]] of length B.
        clip_border (bool, optional): Whether to clip boxes that lie outside
            the border of the image. Defaults to True.

    Returns:
        decoded boxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
    """
    if not isinstance(normalizer, float):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        # In-place operations on slices would fail when exporting to ONNX
        th = h * loc_decode[..., :2]  # tb
        tw = w * loc_decode[..., 2:]  # lr
        loc_decode = torch.cat([th, tw], dim=-1)
    # loc_decode.split(1, dim=-1) cannot be exported to ONNX; use an
    # explicit tuple of split sizes instead
    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    xmin = prior_centers[..., 0].unsqueeze(-1) - left
    xmax = prior_centers[..., 0].unsqueeze(-1) + right
    ymin = prior_centers[..., 1].unsqueeze(-1) - top
    ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            from mmdet.core.export import dynamic_clip_for_onnx
            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
                xmin, ymin, xmax, ymax, max_shape)
            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = priors.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(priors)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = priors.new_tensor(0)
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
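

# A minimal round-trip sketch (illustrative, not part of the original file):
# encode a gt against a prior, decode it back, and check the gt is recovered.
# The 60x60 image shape is made up. Because of the relative import above, run
# this as a module (`python -m <package path>.tblr_bbox_coder`), not as a
# script.
if __name__ == '__main__':
    demo_priors = torch.tensor([[10., 10., 50., 50.]])
    demo_gts = torch.tensor([[12., 8., 48., 52.]])
    demo_deltas = bboxes2tblr(demo_priors, demo_gts, normalizer=4.0)
    # Clipping to the assumed 60x60 image leaves this box untouched.
    demo_restored = tblr2bboxes(
        demo_priors, demo_deltas, normalizer=4.0, max_shape=(60, 60))
    assert torch.allclose(demo_restored, demo_gts)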