bbox_nms.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. from mmcv.ops.nms import batched_nms
  5. from torch import Tensor
  6. from mmdet.structures.bbox import bbox_overlaps
  7. from mmdet.utils import ConfigType
  8. def multiclass_nms(
  9. multi_bboxes: Tensor,
  10. multi_scores: Tensor,
  11. score_thr: float,
  12. nms_cfg: ConfigType,
  13. max_num: int = -1,
  14. score_factors: Optional[Tensor] = None,
  15. return_inds: bool = False,
  16. box_dim: int = 4
  17. ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
  18. """NMS for multi-class bboxes.
  19. Args:
  20. multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
  21. multi_scores (Tensor): shape (n, #class), where the last column
  22. contains scores of the background class, but this will be ignored.
  23. score_thr (float): bbox threshold, bboxes with scores lower than it
  24. will not be considered.
  25. nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains
  26. the arguments of nms operations.
  27. max_num (int, optional): if there are more than max_num bboxes after
  28. NMS, only top max_num will be kept. Default to -1.
  29. score_factors (Tensor, optional): The factors multiplied to scores
  30. before applying NMS. Default to None.
  31. return_inds (bool, optional): Whether return the indices of kept
  32. bboxes. Default to False.
  33. box_dim (int): The dimension of boxes. Defaults to 4.
  34. Returns:
  35. Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
  36. (dets, labels, indices (optional)), tensors of shape (k, 5),
  37. (k), and (k). Dets are boxes with scores. Labels are 0-based.
  38. """
  39. num_classes = multi_scores.size(1) - 1
  40. # exclude background category
  41. if multi_bboxes.shape[1] > box_dim:
  42. bboxes = multi_bboxes.view(multi_scores.size(0), -1, box_dim)
  43. else:
  44. bboxes = multi_bboxes[:, None].expand(
  45. multi_scores.size(0), num_classes, box_dim)
  46. scores = multi_scores[:, :-1]
  47. labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
  48. labels = labels.view(1, -1).expand_as(scores)
  49. bboxes = bboxes.reshape(-1, box_dim)
  50. scores = scores.reshape(-1)
  51. labels = labels.reshape(-1)
  52. if not torch.onnx.is_in_onnx_export():
  53. # NonZero not supported in TensorRT
  54. # remove low scoring boxes
  55. valid_mask = scores > score_thr
  56. # multiply score_factor after threshold to preserve more bboxes, improve
  57. # mAP by 1% for YOLOv3
  58. if score_factors is not None:
  59. # expand the shape to match original shape of score
  60. score_factors = score_factors.view(-1, 1).expand(
  61. multi_scores.size(0), num_classes)
  62. score_factors = score_factors.reshape(-1)
  63. scores = scores * score_factors
  64. if not torch.onnx.is_in_onnx_export():
  65. # NonZero not supported in TensorRT
  66. inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
  67. bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
  68. else:
  69. # TensorRT NMS plugin has invalid output filled with -1
  70. # add dummy data to make detection output correct.
  71. bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0)
  72. scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
  73. labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
  74. if bboxes.numel() == 0:
  75. if torch.onnx.is_in_onnx_export():
  76. raise RuntimeError('[ONNX Error] Can not record NMS '
  77. 'as it has not been executed this time')
  78. dets = torch.cat([bboxes, scores[:, None]], -1)
  79. if return_inds:
  80. return dets, labels, inds
  81. else:
  82. return dets, labels
  83. dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
  84. if max_num > 0:
  85. dets = dets[:max_num]
  86. keep = keep[:max_num]
  87. if return_inds:
  88. return dets, labels[keep], inds[keep]
  89. else:
  90. return dets, labels[keep]
  91. def fast_nms(
  92. multi_bboxes: Tensor,
  93. multi_scores: Tensor,
  94. multi_coeffs: Tensor,
  95. score_thr: float,
  96. iou_thr: float,
  97. top_k: int,
  98. max_num: int = -1
  99. ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
  100. """Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_.
  101. Fast NMS allows already-removed detections to suppress other detections so
  102. that every instance can be decided to be kept or discarded in parallel,
  103. which is not possible in traditional NMS. This relaxation allows us to
  104. implement Fast NMS entirely in standard GPU-accelerated matrix operations.
  105. Args:
  106. multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
  107. multi_scores (Tensor): shape (n, #class+1), where the last column
  108. contains scores of the background class, but this will be ignored.
  109. multi_coeffs (Tensor): shape (n, #class*coeffs_dim).
  110. score_thr (float): bbox threshold, bboxes with scores lower than it
  111. will not be considered.
  112. iou_thr (float): IoU threshold to be considered as conflicted.
  113. top_k (int): if there are more than top_k bboxes before NMS,
  114. only top top_k will be kept.
  115. max_num (int): if there are more than max_num bboxes after NMS,
  116. only top max_num will be kept. If -1, keep all the bboxes.
  117. Default: -1.
  118. Returns:
  119. Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
  120. (dets, labels, coefficients), tensors of shape (k, 5), (k, 1),
  121. and (k, coeffs_dim). Dets are boxes with scores.
  122. Labels are 0-based.
  123. """
  124. scores = multi_scores[:, :-1].t() # [#class, n]
  125. scores, idx = scores.sort(1, descending=True)
  126. idx = idx[:, :top_k].contiguous()
  127. scores = scores[:, :top_k] # [#class, topk]
  128. num_classes, num_dets = idx.size()
  129. boxes = multi_bboxes[idx.view(-1), :].view(num_classes, num_dets, 4)
  130. coeffs = multi_coeffs[idx.view(-1), :].view(num_classes, num_dets, -1)
  131. iou = bbox_overlaps(boxes, boxes) # [#class, topk, topk]
  132. iou.triu_(diagonal=1)
  133. iou_max, _ = iou.max(dim=1)
  134. # Now just filter out the ones higher than the threshold
  135. keep = iou_max <= iou_thr
  136. # Second thresholding introduces 0.2 mAP gain at negligible time cost
  137. keep *= scores > score_thr
  138. # Assign each kept detection to its corresponding class
  139. classes = torch.arange(
  140. num_classes, device=boxes.device)[:, None].expand_as(keep)
  141. classes = classes[keep]
  142. boxes = boxes[keep]
  143. coeffs = coeffs[keep]
  144. scores = scores[keep]
  145. # Only keep the top max_num highest scores across all classes
  146. scores, idx = scores.sort(0, descending=True)
  147. if max_num > 0:
  148. idx = idx[:max_num]
  149. scores = scores[:max_num]
  150. classes = classes[idx]
  151. boxes = boxes[idx]
  152. coeffs = coeffs[idx]
  153. cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
  154. return cls_dets, classes, coeffs