- # Copyright (c) OpenMMLab. All rights reserved.
- from copy import deepcopy
- from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union
- import numpy as np
- import torch
- from mmdet.structures.bbox import HorizontalBoxes
- from torch import Tensor
- DeviceType = Union[str, torch.device]
- T = TypeVar('T')
- IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor,
- torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray]
- class BBoxKeypoints(HorizontalBoxes):
- """The BBoxKeypoints class is a combination of bounding boxes and keypoints
- representation. The box format used in BBoxKeypoints is the same as
- HorizontalBoxes.
- Args:
- data (Tensor or np.ndarray): The box data with shape of
- (N, 4).
- keypoints (Tensor or np.ndarray): The keypoint data with shape of
- (N, K, 2).
- keypoints_visible (Tensor or np.ndarray): The visibility of keypoints
- with shape of (N, K).
- dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
- device (str or torch.device, Optional): device of boxes.
- Default to None.
- clone (bool): Whether clone ``boxes`` or not. Defaults to True.
- in_mode (str, Optional): The mode of the input boxes. If it is
- 'cxcywh', the `data` will be converted to 'xyxy' mode.
- Defaults to None.
- flip_indices (list, Optional): The indices used to reorder keypoints
- when the image is flipped horizontally. Defaults to None.
- Notes:
- N: the number of instances.
- K: the number of keypoints.
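- Examples:
- A minimal construction sketch; the tensors and the ``flip_indices``
- below are illustrative values only.
- >>> import torch
- >>> bboxes = torch.tensor([[10., 10., 50., 80.], [20., 30., 60., 90.]])
- >>> keypoints = torch.rand(2, 3, 2) * 100
- >>> keypoints_visible = torch.ones(2, 3)
- >>> boxes = BBoxKeypoints(
- ...     bboxes, keypoints, keypoints_visible, flip_indices=[0, 2, 1])
- >>> boxes.num_keypoints
- tensor([3, 3], dtype=torch.int32)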
- """
- def __init__(self,
- data: Union[Tensor, np.ndarray],
- keypoints: Union[Tensor, np.ndarray],
- keypoints_visible: Union[Tensor, np.ndarray],
- dtype: Optional[torch.dtype] = None,
- device: Optional[DeviceType] = None,
- clone: bool = True,
- in_mode: Optional[str] = None,
- flip_indices: Optional[List] = None) -> None:
- super().__init__(
- data=data,
- dtype=dtype,
- device=device,
- clone=clone,
- in_mode=in_mode)
- assert len(data) == len(keypoints)
- assert len(data) == len(keypoints_visible)
- assert keypoints.ndim == 3
- assert keypoints_visible.ndim == 2
- keypoints = torch.as_tensor(keypoints)
- keypoints_visible = torch.as_tensor(keypoints_visible)
- if device is not None:
- keypoints = keypoints.to(device=device)
- keypoints_visible = keypoints_visible.to(device=device)
- if clone:
- keypoints = keypoints.clone()
- keypoints_visible = keypoints_visible.clone()
- self.keypoints = keypoints
- self.keypoints_visible = keypoints_visible
- self.flip_indices = flip_indices
- def flip_(self,
- img_shape: Tuple[int, int],
- direction: str = 'horizontal') -> None:
- """Flip boxes & kpts horizontally in-place.
- Args:
- img_shape (Tuple[int, int]): A tuple of image height and width.
- direction (str): Flip direction. Only "horizontal" is supported.
- Defaults to "horizontal".
- """
- assert direction == 'horizontal'
- super().flip_(img_shape, direction)
- self.keypoints[..., 0] = img_shape[1] - self.keypoints[..., 0]
- self.keypoints = self.keypoints[:, self.flip_indices]
- self.keypoints_visible = self.keypoints_visible[:, self.flip_indices]
- def translate_(self, distances: Tuple[float, float]) -> None:
- """Translate boxes and keypoints in-place.
- Args:
- distances (Tuple[float, float]): translate distances. The first
- is horizontal distance and the second is vertical distance.
- """
- boxes = self.tensor
- assert len(distances) == 2
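- # (dx, dy) is repeated to (dx, dy, dx, dy) so both corners of the xyxy boxes are shifted.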
- self.tensor = boxes + boxes.new_tensor(distances).repeat(2)
- distances = self.keypoints.new_tensor(distances).reshape(1, 1, 2)
- self.keypoints = self.keypoints + distances
- def rescale_(self, scale_factor: Tuple[float, float]) -> None:
- """Rescale boxes & keypoints w.r.t. rescale_factor in-place.
- Note:
- Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
- w.r.t. ``scale_factor``. The difference is that ``resize_`` only
- changes the width and the height of boxes, but ``rescale_`` also
- rescales the box centers simultaneously.
- Args:
- scale_factor (Tuple[float, float]): factors for scaling boxes.
- The length should be 2.
- """
- boxes = self.tensor
- assert len(scale_factor) == 2
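- # (w_scale, h_scale) is repeated to scale (x1, y1, x2, y2); the keypoints reuse the same factors below.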
- self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2)
- scale_factor = self.keypoints.new_tensor(scale_factor).reshape(1, 1, 2)
- self.keypoints = self.keypoints * scale_factor
- def clip_(self, img_shape: Tuple[int, int]) -> None:
- """Clip bounding boxes and set invisible keypoints outside the image
- boundary in-place.
- Args:
- img_shape (Tuple[int, int]): A tuple of image height and width.
- """
- boxes = self.tensor
- boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
- boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])
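- # Keypoints are intentionally not clamped; those falling outside the image are marked invisible below.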
- kpt_outside = torch.logical_or(
- torch.logical_or(self.keypoints[..., 0] < 0,
- self.keypoints[..., 1] < 0),
- torch.logical_or(self.keypoints[..., 0] > img_shape[1],
- self.keypoints[..., 1] > img_shape[0]))
- self.keypoints_visible[kpt_outside] *= 0
- def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
- """Geometrically transform bounding boxes and keypoints in-place using
- a homography matrix.
- Args:
- homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray
- representing the homography matrix for the transformation.
- """
- boxes = self.tensor
- if isinstance(homography_matrix, np.ndarray):
- homography_matrix = boxes.new_tensor(homography_matrix)
- # Convert boxes to corners in homogeneous coordinates
- corners = self.hbox2corner(boxes)
- corners = torch.cat(
- [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
- # Convert keypoints to homogeneous coordinates
- keypoints = torch.cat([
- self.keypoints,
- self.keypoints.new_ones(*self.keypoints.shape[:-1], 1)
- ],
- dim=-1)
- # Transpose corners and keypoints for matrix multiplication
- corners_T = torch.transpose(corners, -1, -2)
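- # keypoints: (N, K, 3) -> (3, K, N) -> (3, K*N), so one matmul with the 3x3 matrix transforms every point.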
- keypoints_T = torch.transpose(keypoints, -1, 0).contiguous().flatten(1)
- # Apply homography matrix to corners and keypoints
- corners_T = torch.matmul(homography_matrix, corners_T)
- keypoints_T = torch.matmul(homography_matrix, keypoints_T)
- # Transpose back to original shape
- corners = torch.transpose(corners_T, -1, -2)
- keypoints_T = keypoints_T.reshape(3, self.keypoints.shape[1], -1)
- keypoints = torch.transpose(keypoints_T, -1, 0).contiguous()
- # Convert corners and keypoints back to non-homogeneous coordinates
- corners = corners[..., :2] / corners[..., 2:3]
- keypoints = keypoints[..., :2] / keypoints[..., 2:3]
- # Convert corners back to bounding boxes and update object attributes
- self.tensor = self.corner2hbox(corners)
- self.keypoints = keypoints
- @classmethod
- def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
- """Cancatenates an instance list into one single instance. Similar to
- ``torch.cat``.
- Args:
- box_list (Sequence[T]): A sequence of instances.
- dim (int): The dimension over which the box and keypoint are
- concatenated. Defaults to 0.
- Returns:
- T: Concatenated instance.
- """
- assert isinstance(box_list, Sequence)
- if len(box_list) == 0:
- raise ValueError('box_list should not be an empty list.')
- assert dim == 0
- assert all(isinstance(boxes, cls) for boxes in box_list)
- th_box_list = torch.cat([boxes.tensor for boxes in box_list], dim=dim)
- th_kpt_list = torch.cat([boxes.keypoints for boxes in box_list],
- dim=dim)
- th_kpt_vis_list = torch.cat(
- [boxes.keypoints_visible for boxes in box_list], dim=dim)
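- # All instances are assumed to share the same keypoint ordering, so flip_indices is taken from the first one.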
- flip_indices = box_list[0].flip_indices
- return cls(
- th_box_list,
- th_kpt_list,
- th_kpt_vis_list,
- clone=False,
- flip_indices=flip_indices)
- def __getitem__(self: T, index: IndexType) -> T:
- """Rewrite getitem to protect the last dimension shape."""
- boxes = self.tensor
- if isinstance(index, np.ndarray):
- index = torch.as_tensor(index, device=self.device)
- if isinstance(index, Tensor) and index.dtype == torch.bool:
- assert index.dim() < boxes.dim()
- elif isinstance(index, tuple):
- assert len(index) < boxes.dim()
- # `Ellipsis`(...) is commonly used in index like [None, ...].
- # When `Ellipsis` is in index, it must be the last item.
- if Ellipsis in index:
- assert index[-1] is Ellipsis
- boxes = boxes[index]
- keypoints = self.keypoints[index]
- keypoints_visible = self.keypoints_visible[index]
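- # An integer index drops the instance dimension; restore it so results always keep shape (N, ...).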
- if boxes.dim() == 1:
- boxes = boxes.reshape(1, -1)
- keypoints = keypoints.reshape(1, -1, 2)
- keypoints_visible = keypoints_visible.reshape(1, -1)
- return type(self)(
- boxes,
- keypoints,
- keypoints_visible,
- flip_indices=self.flip_indices,
- clone=False)
- @property
- def num_keypoints(self) -> Tensor:
- """Compute the number of visible keypoints for each object."""
- return self.keypoints_visible.sum(dim=1).int()
- def __deepcopy__(self, memo):
- """Only clone the tensors when applying deepcopy."""
- cls = self.__class__
- other = cls.__new__(cls)
- memo[id(self)] = other
- other.tensor = self.tensor.clone()
- other.keypoints = self.keypoints.clone()
- other.keypoints_visible = self.keypoints_visible.clone()
- other.flip_indices = deepcopy(self.flip_indices)
- return other
- def clone(self: T) -> T:
- """Reload ``clone`` for tensors."""
- return type(self)(
- self.tensor,
- self.keypoints,
- self.keypoints_visible,
- flip_indices=self.flip_indices,
- clone=True)
- def to(self: T, *args, **kwargs) -> T:
- """Reload ``to`` for tensors."""
- return type(self)(
- self.tensor.to(*args, **kwargs),
- self.keypoints.to(*args, **kwargs),
- self.keypoints_visible.to(*args, **kwargs),
- flip_indices=self.flip_indices,
- clone=False)
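- # The lines below sketch how the in-place geometric transforms compose; the
- # shapes, pixel values and flip_indices are assumptions for illustration only.
- # >>> boxes = BBoxKeypoints(
- # ...     torch.tensor([[10., 10., 50., 80.]]),
- # ...     torch.tensor([[[15., 20.], [30., 40.], [45., 60.]]]),
- # ...     torch.ones(1, 3),
- # ...     flip_indices=[0, 2, 1])
- # >>> boxes.flip_((100, 200))      # mirror inside a 100 (h) x 200 (w) image
- # >>> boxes.rescale_((0.5, 0.5))   # halve boxes and keypoints
- # >>> boxes.translate_((5., 0.))   # shift 5 px to the right
- # >>> boxes.clip_((50, 50))        # clip boxes; hide keypoints outside 50x50
- # >>> merged = BBoxKeypoints.cat([boxes, boxes.clone()])
- # >>> merged.tensor.shape
- # torch.Size([2, 4])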