distance_point_bbox_coder.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Sequence, Union
  3. from torch import Tensor
  4. from mmdet.registry import TASK_UTILS
  5. from mmdet.structures.bbox import (BaseBoxes, HorizontalBoxes, bbox2distance,
  6. distance2bbox, get_box_tensor)
  7. from .base_bbox_coder import BaseBBoxCoder
  8. @TASK_UTILS.register_module()
  9. class DistancePointBBoxCoder(BaseBBoxCoder):
  10. """Distance Point BBox coder.
  11. This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
  12. right) and decode it back to the original.
  13. Args:
  14. clip_border (bool, optional): Whether clip the objects outside the
  15. border of the image. Defaults to True.
  16. """
  17. def __init__(self, clip_border: Optional[bool] = True, **kwargs) -> None:
  18. super().__init__(**kwargs)
  19. self.clip_border = clip_border
  20. def encode(self,
  21. points: Tensor,
  22. gt_bboxes: Union[Tensor, BaseBoxes],
  23. max_dis: Optional[float] = None,
  24. eps: float = 0.1) -> Tensor:
  25. """Encode bounding box to distances.
  26. Args:
  27. points (Tensor): Shape (N, 2), The format is [x, y].
  28. gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format
  29. is "xyxy"
  30. max_dis (float): Upper bound of the distance. Default None.
  31. eps (float): a small value to ensure target < max_dis, instead <=.
  32. Default 0.1.
  33. Returns:
  34. Tensor: Box transformation deltas. The shape is (N, 4).
  35. """
  36. gt_bboxes = get_box_tensor(gt_bboxes)
  37. assert points.size(0) == gt_bboxes.size(0)
  38. assert points.size(-1) == 2
  39. assert gt_bboxes.size(-1) == 4
  40. return bbox2distance(points, gt_bboxes, max_dis, eps)
  41. def decode(
  42. self,
  43. points: Tensor,
  44. pred_bboxes: Tensor,
  45. max_shape: Optional[Union[Sequence[int], Tensor,
  46. Sequence[Sequence[int]]]] = None
  47. ) -> Union[Tensor, BaseBoxes]:
  48. """Decode distance prediction to bounding box.
  49. Args:
  50. points (Tensor): Shape (B, N, 2) or (N, 2).
  51. pred_bboxes (Tensor): Distance from the given point to 4
  52. boundaries (left, top, right, bottom). Shape (B, N, 4)
  53. or (N, 4)
  54. max_shape (Sequence[int] or torch.Tensor or Sequence[
  55. Sequence[int]],optional): Maximum bounds for boxes, specifies
  56. (H, W, C) or (H, W). If priors shape is (B, N, 4), then
  57. the max_shape should be a Sequence[Sequence[int]],
  58. and the length of max_shape should also be B.
  59. Default None.
  60. Returns:
  61. Union[Tensor, :obj:`BaseBoxes`]: Boxes with shape (N, 4) or
  62. (B, N, 4)
  63. """
  64. assert points.size(0) == pred_bboxes.size(0)
  65. assert points.size(-1) == 2
  66. assert pred_bboxes.size(-1) == 4
  67. if self.clip_border is False:
  68. max_shape = None
  69. bboxes = distance2bbox(points, pred_bboxes, max_shape)
  70. if self.use_box_type:
  71. bboxes = HorizontalBoxes(bboxes)
  72. return bboxes