123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Sequence, Union
- import torch
- from torch import Tensor
- from mmdet.registry import TASK_UTILS
- from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
- from .base_bbox_coder import BaseBBoxCoder
@TASK_UTILS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """Box coder based on the TBLR (top, bottom, left, right) representation.

    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
    ground-truth boxes given as (x1, y1, x2, y2) are encoded into
    (top, bottom, left, right) targets and can be decoded back to the
    original corner form.

    Args:
        normalizer (Sequence[float] | float): Factor the encoded coordinates
            are divided by. A sequence must have length 4 and holds one
            factor per tblr dimension; a single float is shared by all
            dimensions. Defaults to 4.0.
        clip_border (bool, optional): Whether decoded boxes are clipped to
            the image border when a ``max_shape`` is supplied to
            :meth:`decode`. Defaults to True.
    """

    def __init__(self,
                 normalizer: Union[Sequence[float], float] = 4.0,
                 clip_border: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.clip_border = clip_border
        self.normalizer = normalizer

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
        """Compute the regression targets that map ``bboxes`` onto
        ``gt_bboxes``, expressed in (top, bottom, left, right) order.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g., object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target boxes of
                the transformation, e.g., ground truth boxes.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        source = get_box_tensor(bboxes)
        target = get_box_tensor(gt_bboxes)

        # One target per source box, both in 4-value corner form.
        assert source.size(0) == target.size(0)
        assert source.size(-1) == target.size(-1) == 4
        return bboxes2tblr(source, target, normalizer=self.normalizer)

    def decode(
        self,
        bboxes: Union[Tensor, BaseBoxes],
        pred_bboxes: Tensor,
        max_shape: Optional[Union[Sequence[int], Tensor,
                                  Sequence[Sequence[int]]]] = None
    ) -> Union[Tensor, BaseBoxes]:
        """Turn the encoded ``pred_bboxes`` back into boxes relative to
        ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
                (B, N, 4) or (N, 4).
            pred_bboxes (torch.Tensor): Encoded boxes with shape
                (B, N, 4) or (N, 4).
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Maximum bounds for boxes, given
                as (H, W, C) or (H, W). If ``bboxes`` has shape (B, N, 4),
                ``max_shape`` should be a Sequence[Sequence[int]] of
                length B.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes, wrapped
            in :obj:`HorizontalBoxes` when ``self.use_box_type`` is set.
        """
        prior_tensor = get_box_tensor(bboxes)
        decoded = tblr2bboxes(
            prior_tensor,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape,
            clip_border=self.clip_border)
        return HorizontalBoxes(decoded) if self.use_box_type else decoded
def bboxes2tblr(priors: Tensor,
                gts: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True) -> Tensor:
    """Encode ground truth boxes to tblr coordinate.

    The gt coordinates are first converted to tblr format
    (top, bottom, left, right), measured from the prior box centers.
    If ``normalize_by_wh`` is True the top/bottom values are divided by the
    prior height and the left/right values by the prior width. The result
    is finally divided by ``normalizer``.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1).
            Shape: (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form.
            Shape: (num_proposals, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. A plain number (``int`` or ``float``) is applied
            to all dims; a sequence must have length 4 and is applied per
            tblr dim. Defaults to 4.0.
        normalize_by_wh (bool): Whether to normalize tblr coordinate by the
            side length (wh) of prior bboxes. Defaults to True.

    Returns:
        Tensor: Encoded boxes with shape (num_proposals, 4).
    """
    # Plain Python numbers are used as a scalar factor. Anything else is
    # treated as a per-dimension sequence. Checking only for ``float`` would
    # send an ``int`` such as ``4`` down the sequence branch, where the
    # resulting 0-dim tensor has no ``len()``.
    if not isinstance(normalizer, (int, float)):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)

    # Distances from each prior's center to the four sides of its gt box.
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    center_x = prior_centers[:, 0].unsqueeze(1)
    center_y = prior_centers[:, 1].unsqueeze(1)
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = center_y - ymin
    bottom = ymax - center_y
    left = center_x - xmin
    right = xmax - center_x
    loc = torch.cat((top, bottom, left, right), dim=1)

    if normalize_by_wh:
        # Normalize tblr by anchor width and height. ``loc`` is a fresh
        # tensor produced by ``torch.cat``, so the in-place division does
        # not touch the inputs.
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h  # tb is normalized by h
        loc[:, 2:] /= w  # lr is normalized by w

    # Normalize tblr by the given normalization factor
    return loc / normalizer
def tblr2bboxes(priors: Tensor,
                tblr: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True,
                max_shape: Optional[Union[Sequence[int], Tensor,
                                          Sequence[Sequence[int]]]] = None,
                clip_border: bool = True) -> Tensor:
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize tblr coordinates by
    multiplying it with `normalizer`; 2) De-normalize tblr coordinates by the
    prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert
    tblr (top, bottom, left, right) pair relative to the center of priors back
    to (xmin, ymin, xmax, ymax) coordinate.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1).
            Shape: (N, 4) or (B, N, 4).
        tblr (Tensor): Coords of network output in tblr form.
            Shape: (N, 4) or (B, N, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. A plain number (``int`` or ``float``) is the
            unified factor for all dims; a sequence must have length 4 and
            gives the factors at tblr dims. Defaults to 4.0.
        normalize_by_wh (bool): Whether the tblr coordinates have been
            normalized by the side length (wh) of prior bboxes.
            Defaults to True.
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Clipping only happens when ``max_shape``
            is also given. Defaults to True.

    Returns:
        Tensor: Decoded boxes with shape (N, 4) or (B, N, 4).
    """
    # Plain Python numbers are used as a scalar factor. Anything else is
    # treated as a per-dimension sequence. Checking only for ``float`` would
    # send an ``int`` such as ``4`` down the sequence branch, where the
    # resulting 0-dim tensor has no ``len()``.
    if not isinstance(normalizer, (int, float)):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    # Step 1: undo the normalizer scaling.
    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        # Step 2: undo the scaling by prior height (tb) and width (lr).
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        # Inplace operation with slice would fail when exporting to ONNX
        th = h * loc_decode[..., :2]  # tb
        tw = w * loc_decode[..., 2:]  # lr
        loc_decode = torch.cat([th, tw], dim=-1)

    # Step 3: convert center-relative distances back to corner form.
    # Cannot be exported using onnx when loc_decode.split(1, dim=-1)
    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    center_x = prior_centers[..., 0].unsqueeze(-1)
    center_y = prior_centers[..., 1].unsqueeze(-1)
    xmin = center_x - left
    xmax = center_x + right
    ymin = center_y - top
    ymax = center_y + bottom
    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)

    if clip_border and max_shape is not None:
        # clip bboxes with dynamic `min` and `max` for onnx
        if torch.onnx.is_in_onnx_export():
            # NOTE(review): this import path differs from the package
            # layout used by the other mmdet imports in this file; confirm
            # the module still exists before relying on ONNX export.
            from mmdet.core.export import dynamic_clip_for_onnx
            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
                xmin, ymin, xmax, ymax, max_shape)
            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = priors.new_tensor(max_shape)
        # Keep only (H, W); an optional channel entry is dropped.
        max_shape = max_shape[..., :2].type_as(priors)
        if max_shape.ndim == 2:
            # Per-image bounds require batched boxes of matching size.
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = priors.new_tensor(0)
        # (H, W, H, W) flipped to (W, H, W, H) lines up with
        # (xmin, ymin, xmax, ymax).
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
|