gaussian_target.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from math import sqrt
  3. import torch
  4. import torch.nn.functional as F
  5. def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'):
  6. """Generate 2D gaussian kernel.
  7. Args:
  8. radius (int): Radius of gaussian kernel.
  9. sigma (int): Sigma of gaussian function. Default: 1.
  10. dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32.
  11. device (str): Device of gaussian tensor. Default: 'cpu'.
  12. Returns:
  13. h (Tensor): Gaussian kernel with a
  14. ``(2 * radius + 1) * (2 * radius + 1)`` shape.
  15. """
  16. x = torch.arange(
  17. -radius, radius + 1, dtype=dtype, device=device).view(1, -1)
  18. y = torch.arange(
  19. -radius, radius + 1, dtype=dtype, device=device).view(-1, 1)
  20. h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
  21. h[h < torch.finfo(h.dtype).eps * h.max()] = 0
  22. return h
  23. def gen_gaussian_target(heatmap, center, radius, k=1):
  24. """Generate 2D gaussian heatmap.
  25. Args:
  26. heatmap (Tensor): Input heatmap, the gaussian kernel will cover on
  27. it and maintain the max value.
  28. center (list[int]): Coord of gaussian kernel's center.
  29. radius (int): Radius of gaussian kernel.
  30. k (int): Coefficient of gaussian kernel. Default: 1.
  31. Returns:
  32. out_heatmap (Tensor): Updated heatmap covered by gaussian kernel.
  33. """
  34. diameter = 2 * radius + 1
  35. gaussian_kernel = gaussian2D(
  36. radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device)
  37. x, y = center
  38. height, width = heatmap.shape[:2]
  39. left, right = min(x, radius), min(width - x, radius + 1)
  40. top, bottom = min(y, radius), min(height - y, radius + 1)
  41. masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
  42. masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
  43. radius - left:radius + right]
  44. out_heatmap = heatmap
  45. torch.max(
  46. masked_heatmap,
  47. masked_gaussian * k,
  48. out=out_heatmap[y - top:y + bottom, x - left:x + right])
  49. return out_heatmap
  50. def gaussian_radius(det_size, min_overlap):
  51. r"""Generate 2D gaussian radius.
  52. This function is modified from the `official github repo
  53. <https://github.com/princeton-vl/CornerNet-Lite/blob/master/core/sample/
  54. utils.py#L65>`_.
  55. Given ``min_overlap``, radius could computed by a quadratic equation
  56. according to Vieta's formulas.
  57. There are 3 cases for computing gaussian radius, details are following:
  58. - Explanation of figure: ``lt`` and ``br`` indicates the left-top and
  59. bottom-right corner of ground truth box. ``x`` indicates the
  60. generated corner at the limited position when ``radius=r``.
  61. - Case1: one corner is inside the gt box and the other is outside.
  62. .. code:: text
  63. |< width >|
  64. lt-+----------+ -
  65. | | | ^
  66. +--x----------+--+
  67. | | | |
  68. | | | | height
  69. | | overlap | |
  70. | | | |
  71. | | | | v
  72. +--+---------br--+ -
  73. | | |
  74. +----------+--x
  75. To ensure IoU of generated box and gt box is larger than ``min_overlap``:
  76. .. math::
  77. \cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad
  78. {r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\
  79. {a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h}
  80. {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
  81. - Case2: both two corners are inside the gt box.
  82. .. code:: text
  83. |< width >|
  84. lt-+----------+ -
  85. | | | ^
  86. +--x-------+ |
  87. | | | |
  88. | |overlap| | height
  89. | | | |
  90. | +-------x--+
  91. | | | v
  92. +----------+-br -
  93. To ensure IoU of generated box and gt box is larger than ``min_overlap``:
  94. .. math::
  95. \cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad
  96. {4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\
  97. {a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h}
  98. {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
  99. - Case3: both two corners are outside the gt box.
  100. .. code:: text
  101. |< width >|
  102. x--+----------------+
  103. | | |
  104. +-lt-------------+ | -
  105. | | | | ^
  106. | | | |
  107. | | overlap | | height
  108. | | | |
  109. | | | | v
  110. | +------------br--+ -
  111. | | |
  112. +----------------+--x
  113. To ensure IoU of generated box and gt box is larger than ``min_overlap``:
  114. .. math::
  115. \cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad
  116. {4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\
  117. {a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\
  118. {r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a}
  119. Args:
  120. det_size (list[int]): Shape of object.
  121. min_overlap (float): Min IoU with ground truth for boxes generated by
  122. keypoints inside the gaussian kernel.
  123. Returns:
  124. radius (int): Radius of gaussian kernel.
  125. """
  126. height, width = det_size
  127. a1 = 1
  128. b1 = (height + width)
  129. c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
  130. sq1 = sqrt(b1**2 - 4 * a1 * c1)
  131. r1 = (b1 - sq1) / (2 * a1)
  132. a2 = 4
  133. b2 = 2 * (height + width)
  134. c2 = (1 - min_overlap) * width * height
  135. sq2 = sqrt(b2**2 - 4 * a2 * c2)
  136. r2 = (b2 - sq2) / (2 * a2)
  137. a3 = 4 * min_overlap
  138. b3 = -2 * min_overlap * (height + width)
  139. c3 = (min_overlap - 1) * width * height
  140. sq3 = sqrt(b3**2 - 4 * a3 * c3)
  141. r3 = (b3 + sq3) / (2 * a3)
  142. return min(r1, r2, r3)
  143. def get_local_maximum(heat, kernel=3):
  144. """Extract local maximum pixel with given kernel.
  145. Args:
  146. heat (Tensor): Target heatmap.
  147. kernel (int): Kernel size of max pooling. Default: 3.
  148. Returns:
  149. heat (Tensor): A heatmap where local maximum pixels maintain its
  150. own value and other positions are 0.
  151. """
  152. pad = (kernel - 1) // 2
  153. hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
  154. keep = (hmax == heat).float()
  155. return heat * keep
  156. def get_topk_from_heatmap(scores, k=20):
  157. """Get top k positions from heatmap.
  158. Args:
  159. scores (Tensor): Target heatmap with shape
  160. [batch, num_classes, height, width].
  161. k (int): Target number. Default: 20.
  162. Returns:
  163. tuple[torch.Tensor]: Scores, indexes, categories and coords of
  164. topk keypoint. Containing following Tensors:
  165. - topk_scores (Tensor): Max scores of each topk keypoint.
  166. - topk_inds (Tensor): Indexes of each topk keypoint.
  167. - topk_clses (Tensor): Categories of each topk keypoint.
  168. - topk_ys (Tensor): Y-coord of each topk keypoint.
  169. - topk_xs (Tensor): X-coord of each topk keypoint.
  170. """
  171. batch, _, height, width = scores.size()
  172. topk_scores, topk_inds = torch.topk(scores.view(batch, -1), k)
  173. topk_clses = topk_inds // (height * width)
  174. topk_inds = topk_inds % (height * width)
  175. topk_ys = topk_inds // width
  176. topk_xs = (topk_inds % width).int().float()
  177. return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
  178. def gather_feat(feat, ind, mask=None):
  179. """Gather feature according to index.
  180. Args:
  181. feat (Tensor): Target feature map.
  182. ind (Tensor): Target coord index.
  183. mask (Tensor | None): Mask of feature map. Default: None.
  184. Returns:
  185. feat (Tensor): Gathered feature.
  186. """
  187. dim = feat.size(2)
  188. ind = ind.unsqueeze(2).repeat(1, 1, dim)
  189. feat = feat.gather(1, ind)
  190. if mask is not None:
  191. mask = mask.unsqueeze(2).expand_as(feat)
  192. feat = feat[mask]
  193. feat = feat.view(-1, dim)
  194. return feat
  195. def transpose_and_gather_feat(feat, ind):
  196. """Transpose and gather feature according to index.
  197. Args:
  198. feat (Tensor): Target feature map.
  199. ind (Tensor): Target coord index.
  200. Returns:
  201. feat (Tensor): Transposed and gathered feature.
  202. """
  203. feat = feat.permute(0, 2, 3, 1).contiguous()
  204. feat = feat.view(feat.size(0), -1, feat.size(3))
  205. feat = gather_feat(feat, ind)
  206. return feat