# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor

from .base_bbox_coder import BaseBBoxCoder
  8. @TASK_UTILS.register_module()
  9. class YOLOBBoxCoder(BaseBBoxCoder):
  10. """YOLO BBox coder.
  11. Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divide
  12. image into grids, and encode bbox (x1, y1, x2, y2) into (cx, cy, dw, dh).
  13. cx, cy in [0., 1.], denotes relative center position w.r.t the center of
  14. bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`.
  15. Args:
  16. eps (float): Min value of cx, cy when encoding.
  17. """
  18. def __init__(self, eps: float = 1e-6, **kwargs):
  19. super().__init__(**kwargs)
  20. self.eps = eps
  21. def encode(self, bboxes: Union[Tensor, BaseBoxes],
  22. gt_bboxes: Union[Tensor, BaseBoxes],
  23. stride: Union[Tensor, int]) -> Tensor:
  24. """Get box regression transformation deltas that can be used to
  25. transform the ``bboxes`` into the ``gt_bboxes``.
  26. Args:
  27. bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
  28. e.g., anchors.
  29. gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
  30. transformation, e.g., ground-truth boxes.
  31. stride (torch.Tensor | int): Stride of bboxes.
  32. Returns:
  33. torch.Tensor: Box transformation deltas
  34. """
  35. bboxes = get_box_tensor(bboxes)
  36. gt_bboxes = get_box_tensor(gt_bboxes)
  37. assert bboxes.size(0) == gt_bboxes.size(0)
  38. assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
  39. x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
  40. y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
  41. w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
  42. h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
  43. x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
  44. y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
  45. w = bboxes[..., 2] - bboxes[..., 0]
  46. h = bboxes[..., 3] - bboxes[..., 1]
  47. w_target = torch.log((w_gt / w).clamp(min=self.eps))
  48. h_target = torch.log((h_gt / h).clamp(min=self.eps))
  49. x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
  50. self.eps, 1 - self.eps)
  51. y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
  52. self.eps, 1 - self.eps)
  53. encoded_bboxes = torch.stack(
  54. [x_center_target, y_center_target, w_target, h_target], dim=-1)
  55. return encoded_bboxes
  56. def decode(self, bboxes: Union[Tensor, BaseBoxes], pred_bboxes: Tensor,
  57. stride: Union[Tensor, int]) -> Union[Tensor, BaseBoxes]:
  58. """Apply transformation `pred_bboxes` to `boxes`.
  59. Args:
  60. boxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes,
  61. e.g. anchors.
  62. pred_bboxes (torch.Tensor): Encoded boxes with shape
  63. stride (torch.Tensor | int): Strides of bboxes.
  64. Returns:
  65. Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
  66. """
  67. bboxes = get_box_tensor(bboxes)
  68. assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
  69. xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
  70. pred_bboxes[..., :2] - 0.5) * stride
  71. whs = (bboxes[..., 2:] -
  72. bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
  73. decoded_bboxes = torch.stack(
  74. (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
  75. whs[..., 1], xy_centers[..., 0] + whs[..., 0],
  76. xy_centers[..., 1] + whs[..., 1]),
  77. dim=-1)
  78. if self.use_box_type:
  79. decoded_bboxes = HorizontalBoxes(decoded_bboxes)
  80. return decoded_bboxes