bbox_overlaps.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. def fp16_clamp(x, min=None, max=None):
  4. if not x.is_cuda and x.dtype == torch.float16:
  5. # clamp for cpu float16, tensor fp16 has no clamp implementation
  6. return x.float().clamp(min, max).half()
  7. return x.clamp(min, max)
  8. def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
  9. """Calculate overlap between two set of bboxes.
  10. FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889
  11. Note:
  12. Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou',
  13. there are some new generated variable when calculating IOU
  14. using bbox_overlaps function:
  15. 1) is_aligned is False
  16. area1: M x 1
  17. area2: N x 1
  18. lt: M x N x 2
  19. rb: M x N x 2
  20. wh: M x N x 2
  21. overlap: M x N x 1
  22. union: M x N x 1
  23. ious: M x N x 1
  24. Total memory:
  25. S = (9 x N x M + N + M) * 4 Byte,
  26. When using FP16, we can reduce:
  27. R = (9 x N x M + N + M) * 4 / 2 Byte
  28. R large than (N + M) * 4 * 2 is always true when N and M >= 1.
  29. Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2,
  30. N + 1 < 3 * N, when N or M is 1.
  31. Given M = 40 (ground truth), N = 400000 (three anchor boxes
  32. in per grid, FPN, R-CNNs),
  33. R = 275 MB (one times)
  34. A special case (dense detection), M = 512 (ground truth),
  35. R = 3516 MB = 3.43 GB
  36. When the batch size is B, reduce:
  37. B x R
  38. Therefore, CUDA memory runs out frequently.
  39. Experiments on GeForce RTX 2080Ti (11019 MiB):
  40. | dtype | M | N | Use | Real | Ideal |
  41. |:----:|:----:|:----:|:----:|:----:|:----:|
  42. | FP32 | 512 | 400000 | 8020 MiB | -- | -- |
  43. | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
  44. | FP32 | 40 | 400000 | 1540 MiB | -- | -- |
  45. | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB |
  46. 2) is_aligned is True
  47. area1: N x 1
  48. area2: N x 1
  49. lt: N x 2
  50. rb: N x 2
  51. wh: N x 2
  52. overlap: N x 1
  53. union: N x 1
  54. ious: N x 1
  55. Total memory:
  56. S = 11 x N * 4 Byte
  57. When using FP16, we can reduce:
  58. R = 11 x N * 4 / 2 Byte
  59. So do the 'giou' (large than 'iou').
  60. Time-wise, FP16 is generally faster than FP32.
  61. When gpu_assign_thr is not -1, it takes more time on cpu
  62. but not reduce memory.
  63. There, we can reduce half the memory and keep the speed.
  64. If ``is_aligned`` is ``False``, then calculate the overlaps between each
  65. bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
  66. pair of bboxes1 and bboxes2.
  67. Args:
  68. bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
  69. bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
  70. B indicates the batch dim, in shape (B1, B2, ..., Bn).
  71. If ``is_aligned`` is ``True``, then m and n must be equal.
  72. mode (str): "iou" (intersection over union), "iof" (intersection over
  73. foreground) or "giou" (generalized intersection over union).
  74. Default "iou".
  75. is_aligned (bool, optional): If True, then m and n must be equal.
  76. Default False.
  77. eps (float, optional): A value added to the denominator for numerical
  78. stability. Default 1e-6.
  79. Returns:
  80. Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
  81. Example:
  82. >>> bboxes1 = torch.FloatTensor([
  83. >>> [0, 0, 10, 10],
  84. >>> [10, 10, 20, 20],
  85. >>> [32, 32, 38, 42],
  86. >>> ])
  87. >>> bboxes2 = torch.FloatTensor([
  88. >>> [0, 0, 10, 20],
  89. >>> [0, 10, 10, 19],
  90. >>> [10, 10, 20, 20],
  91. >>> ])
  92. >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
  93. >>> assert overlaps.shape == (3, 3)
  94. >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
  95. >>> assert overlaps.shape == (3, )
  96. Example:
  97. >>> empty = torch.empty(0, 4)
  98. >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
  99. >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
  100. >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
  101. >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
  102. """
  103. assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
  104. # Either the boxes are empty or the length of boxes' last dimension is 4
  105. assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
  106. assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
  107. # Batch dim must be the same
  108. # Batch dim: (B1, B2, ... Bn)
  109. assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
  110. batch_shape = bboxes1.shape[:-2]
  111. rows = bboxes1.size(-2)
  112. cols = bboxes2.size(-2)
  113. if is_aligned:
  114. assert rows == cols
  115. if rows * cols == 0:
  116. if is_aligned:
  117. return bboxes1.new(batch_shape + (rows, ))
  118. else:
  119. return bboxes1.new(batch_shape + (rows, cols))
  120. area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
  121. bboxes1[..., 3] - bboxes1[..., 1])
  122. area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
  123. bboxes2[..., 3] - bboxes2[..., 1])
  124. if is_aligned:
  125. lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
  126. rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
  127. wh = fp16_clamp(rb - lt, min=0)
  128. overlap = wh[..., 0] * wh[..., 1]
  129. if mode in ['iou', 'giou']:
  130. union = area1 + area2 - overlap
  131. else:
  132. union = area1
  133. if mode == 'giou':
  134. enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
  135. enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
  136. else:
  137. lt = torch.max(bboxes1[..., :, None, :2],
  138. bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
  139. rb = torch.min(bboxes1[..., :, None, 2:],
  140. bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
  141. wh = fp16_clamp(rb - lt, min=0)
  142. overlap = wh[..., 0] * wh[..., 1]
  143. if mode in ['iou', 'giou']:
  144. union = area1[..., None] + area2[..., None, :] - overlap
  145. else:
  146. union = area1[..., None]
  147. if mode == 'giou':
  148. enclosed_lt = torch.min(bboxes1[..., :, None, :2],
  149. bboxes2[..., None, :, :2])
  150. enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
  151. bboxes2[..., None, :, 2:])
  152. eps = union.new_tensor([eps])
  153. union = torch.max(union, eps)
  154. ious = overlap / union
  155. if mode in ['iou', 'iof']:
  156. return ious
  157. # calculate gious
  158. enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
  159. enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
  160. enclose_area = torch.max(enclose_area, eps)
  161. gious = ious - (enclose_area - union) / enclose_area
  162. return gious