bbox_keypoint_structure.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from copy import deepcopy
  3. from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union
  4. import numpy as np
  5. import torch
  6. from mmdet.structures.bbox import HorizontalBoxes
  7. from torch import Tensor
  8. DeviceType = Union[str, torch.device]
  9. T = TypeVar('T')
  10. IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor,
  11. torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray]
  12. class BBoxKeypoints(HorizontalBoxes):
  13. """The BBoxKeypoints class is a combination of bounding boxes and keypoints
  14. representation. The box format used in BBoxKeypoints is the same as
  15. HorizontalBoxes.
  16. Args:
  17. data (Tensor or np.ndarray): The box data with shape of
  18. (N, 4).
  19. keypoints (Tensor or np.ndarray): The keypoint data with shape of
  20. (N, K, 2).
  21. keypoints_visible (Tensor or np.ndarray): The visibility of keypoints
  22. with shape of (N, K).
  23. dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
  24. device (str or torch.device, Optional): device of boxes.
  25. Default to None.
  26. clone (bool): Whether clone ``boxes`` or not. Defaults to True.
  27. mode (str, Optional): the mode of boxes. If it is 'cxcywh', the
  28. `data` will be converted to 'xyxy' mode. Defaults to None.
  29. flip_indices (list, Optional): The indices of keypoints when the
  30. images is flipped. Defaults to None.
  31. Notes:
  32. N: the number of instances.
  33. K: the number of keypoints.
  34. """
  35. def __init__(self,
  36. data: Union[Tensor, np.ndarray],
  37. keypoints: Union[Tensor, np.ndarray],
  38. keypoints_visible: Union[Tensor, np.ndarray],
  39. dtype: Optional[torch.dtype] = None,
  40. device: Optional[DeviceType] = None,
  41. clone: bool = True,
  42. in_mode: Optional[str] = None,
  43. flip_indices: Optional[List] = None) -> None:
  44. super().__init__(
  45. data=data,
  46. dtype=dtype,
  47. device=device,
  48. clone=clone,
  49. in_mode=in_mode)
  50. assert len(data) == len(keypoints)
  51. assert len(data) == len(keypoints_visible)
  52. assert keypoints.ndim == 3
  53. assert keypoints_visible.ndim == 2
  54. keypoints = torch.as_tensor(keypoints)
  55. keypoints_visible = torch.as_tensor(keypoints_visible)
  56. if device is not None:
  57. keypoints = keypoints.to(device=device)
  58. keypoints_visible = keypoints_visible.to(device=device)
  59. if clone:
  60. keypoints = keypoints.clone()
  61. keypoints_visible = keypoints_visible.clone()
  62. self.keypoints = keypoints
  63. self.keypoints_visible = keypoints_visible
  64. self.flip_indices = flip_indices
  65. def flip_(self,
  66. img_shape: Tuple[int, int],
  67. direction: str = 'horizontal') -> None:
  68. """Flip boxes & kpts horizontally in-place.
  69. Args:
  70. img_shape (Tuple[int, int]): A tuple of image height and width.
  71. direction (str): Flip direction, options are "horizontal",
  72. "vertical" and "diagonal". Defaults to "horizontal"
  73. """
  74. assert direction == 'horizontal'
  75. super().flip_(img_shape, direction)
  76. self.keypoints[..., 0] = img_shape[1] - self.keypoints[..., 0]
  77. self.keypoints = self.keypoints[:, self.flip_indices]
  78. self.keypoints_visible = self.keypoints_visible[:, self.flip_indices]
  79. def translate_(self, distances: Tuple[float, float]) -> None:
  80. """Translate boxes and keypoints in-place.
  81. Args:
  82. distances (Tuple[float, float]): translate distances. The first
  83. is horizontal distance and the second is vertical distance.
  84. """
  85. boxes = self.tensor
  86. assert len(distances) == 2
  87. self.tensor = boxes + boxes.new_tensor(distances).repeat(2)
  88. distances = self.keypoints.new_tensor(distances).reshape(1, 1, 2)
  89. self.keypoints = self.keypoints + distances
  90. def rescale_(self, scale_factor: Tuple[float, float]) -> None:
  91. """Rescale boxes & keypoints w.r.t. rescale_factor in-place.
  92. Note:
  93. Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
  94. w.r.t ``scale_facotr``. The difference is that ``resize_`` only
  95. changes the width and the height of boxes, but ``rescale_`` also
  96. rescales the box centers simultaneously.
  97. Args:
  98. scale_factor (Tuple[float, float]): factors for scaling boxes.
  99. The length should be 2.
  100. """
  101. boxes = self.tensor
  102. assert len(scale_factor) == 2
  103. self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2)
  104. scale_factor = self.keypoints.new_tensor(scale_factor).reshape(1, 1, 2)
  105. self.keypoints = self.keypoints * scale_factor
  106. def clip_(self, img_shape: Tuple[int, int]) -> None:
  107. """Clip bounding boxes and set invisible keypoints outside the image
  108. boundary in-place.
  109. Args:
  110. img_shape (Tuple[int, int]): A tuple of image height and width.
  111. """
  112. boxes = self.tensor
  113. boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
  114. boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])
  115. kpt_outside = torch.logical_or(
  116. torch.logical_or(self.keypoints[..., 0] < 0,
  117. self.keypoints[..., 1] < 0),
  118. torch.logical_or(self.keypoints[..., 0] > img_shape[1],
  119. self.keypoints[..., 1] > img_shape[0]))
  120. self.keypoints_visible[kpt_outside] *= 0
  121. def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
  122. """Geometrically transform bounding boxes and keypoints in-place using
  123. a homography matrix.
  124. Args:
  125. homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray
  126. representing the homography matrix for the transformation.
  127. """
  128. boxes = self.tensor
  129. if isinstance(homography_matrix, np.ndarray):
  130. homography_matrix = boxes.new_tensor(homography_matrix)
  131. # Convert boxes to corners in homogeneous coordinates
  132. corners = self.hbox2corner(boxes)
  133. corners = torch.cat(
  134. [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
  135. # Convert keypoints to homogeneous coordinates
  136. keypoints = torch.cat([
  137. self.keypoints,
  138. self.keypoints.new_ones(*self.keypoints.shape[:-1], 1)
  139. ],
  140. dim=-1)
  141. # Transpose corners and keypoints for matrix multiplication
  142. corners_T = torch.transpose(corners, -1, -2)
  143. keypoints_T = torch.transpose(keypoints, -1, 0).contiguous().flatten(1)
  144. # Apply homography matrix to corners and keypoints
  145. corners_T = torch.matmul(homography_matrix, corners_T)
  146. keypoints_T = torch.matmul(homography_matrix, keypoints_T)
  147. # Transpose back to original shape
  148. corners = torch.transpose(corners_T, -1, -2)
  149. keypoints_T = keypoints_T.reshape(3, self.keypoints.shape[1], -1)
  150. keypoints = torch.transpose(keypoints_T, -1, 0).contiguous()
  151. # Convert corners and keypoints back to non-homogeneous coordinates
  152. corners = corners[..., :2] / corners[..., 2:3]
  153. keypoints = keypoints[..., :2] / keypoints[..., 2:3]
  154. # Convert corners back to bounding boxes and update object attributes
  155. self.tensor = self.corner2hbox(corners)
  156. self.keypoints = keypoints
  157. @classmethod
  158. def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
  159. """Cancatenates an instance list into one single instance. Similar to
  160. ``torch.cat``.
  161. Args:
  162. box_list (Sequence[T]): A sequence of instances.
  163. dim (int): The dimension over which the box and keypoint are
  164. concatenated. Defaults to 0.
  165. Returns:
  166. T: Concatenated instance.
  167. """
  168. assert isinstance(box_list, Sequence)
  169. if len(box_list) == 0:
  170. raise ValueError('box_list should not be a empty list.')
  171. assert dim == 0
  172. assert all(isinstance(boxes, cls) for boxes in box_list)
  173. th_box_list = torch.cat([boxes.tensor for boxes in box_list], dim=dim)
  174. th_kpt_list = torch.cat([boxes.keypoints for boxes in box_list],
  175. dim=dim)
  176. th_kpt_vis_list = torch.cat(
  177. [boxes.keypoints_visible for boxes in box_list], dim=dim)
  178. flip_indices = box_list[0].flip_indices
  179. return cls(
  180. th_box_list,
  181. th_kpt_list,
  182. th_kpt_vis_list,
  183. clone=False,
  184. flip_indices=flip_indices)
  185. def __getitem__(self: T, index: IndexType) -> T:
  186. """Rewrite getitem to protect the last dimension shape."""
  187. boxes = self.tensor
  188. if isinstance(index, np.ndarray):
  189. index = torch.as_tensor(index, device=self.device)
  190. if isinstance(index, Tensor) and index.dtype == torch.bool:
  191. assert index.dim() < boxes.dim()
  192. elif isinstance(index, tuple):
  193. assert len(index) < boxes.dim()
  194. # `Ellipsis`(...) is commonly used in index like [None, ...].
  195. # When `Ellipsis` is in index, it must be the last item.
  196. if Ellipsis in index:
  197. assert index[-1] is Ellipsis
  198. boxes = boxes[index]
  199. keypoints = self.keypoints[index]
  200. keypoints_visible = self.keypoints_visible[index]
  201. if boxes.dim() == 1:
  202. boxes = boxes.reshape(1, -1)
  203. keypoints = keypoints.reshape(1, -1, 2)
  204. keypoints_visible = keypoints_visible.reshape(1, -1)
  205. return type(self)(
  206. boxes,
  207. keypoints,
  208. keypoints_visible,
  209. flip_indices=self.flip_indices,
  210. clone=False)
  211. @property
  212. def num_keypoints(self) -> Tensor:
  213. """Compute the number of visible keypoints for each object."""
  214. return self.keypoints_visible.sum(dim=1).int()
  215. def __deepcopy__(self, memo):
  216. """Only clone the tensors when applying deepcopy."""
  217. cls = self.__class__
  218. other = cls.__new__(cls)
  219. memo[id(self)] = other
  220. other.tensor = self.tensor.clone()
  221. other.keypoints = self.keypoints.clone()
  222. other.keypoints_visible = self.keypoints_visible.clone()
  223. other.flip_indices = deepcopy(self.flip_indices)
  224. return other
  225. def clone(self: T) -> T:
  226. """Reload ``clone`` for tensors."""
  227. return type(self)(
  228. self.tensor,
  229. self.keypoints,
  230. self.keypoints_visible,
  231. flip_indices=self.flip_indices,
  232. clone=True)
  233. def to(self: T, *args, **kwargs) -> T:
  234. """Reload ``to`` for tensors."""
  235. return type(self)(
  236. self.tensor.to(*args, **kwargs),
  237. self.keypoints.to(*args, **kwargs),
  238. self.keypoints_visible.to(*args, **kwargs),
  239. flip_indices=self.flip_indices,
  240. clone=False)