123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Tuple, Union
- import torch
- from mmcv.ops.nms import batched_nms
- from torch import Tensor
- from mmdet.structures.bbox import bbox_overlaps
- from mmdet.utils import ConfigType
def multiclass_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    score_thr: float,
    nms_cfg: ConfigType,
    max_num: int = -1,
    score_factors: Optional[Tensor] = None,
    return_inds: bool = False,
    box_dim: int = 4
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*box_dim) for class-specific
            boxes, or (n, box_dim) for boxes shared by every class.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold. Outside ONNX export, only bboxes
            whose raw score (before ``score_factors`` is applied) is strictly
            greater than it are considered.
        nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains
            the arguments of nms operations; forwarded to ``batched_nms``.
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Default to -1 (keep all).
        score_factors (Tensor, optional): Factors with one value per row of
            ``multi_scores``, multiplied to scores after thresholding and
            before applying NMS. Default to None.
        return_inds (bool, optional): Whether return the indices of kept
            bboxes. Default to False. Not supported during ONNX export.
        box_dim (int): The dimension of boxes. Defaults to 4.

    Returns:
        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
            (dets, labels, indices (optional)), tensors of shape
            (k, box_dim + 1), (k), and (k). Dets are boxes with their score
            in the last column. Labels are 0-based. Indices point into the
            flattened list of n * #class (box, class) candidates.

    Raises:
        RuntimeError: if ``return_inds`` is True during ONNX export.
    """
    in_onnx_export = torch.onnx.is_in_onnx_export()
    if return_inds and in_onnx_export:
        # The export branch below never computes ``inds``; fail with a clear
        # message instead of an UnboundLocalError at the return statement.
        raise RuntimeError('[ONNX Error] return_inds=True is not supported '
                           'during ONNX export')

    num_boxes = multi_scores.size(0)
    # exclude background category (last column of multi_scores)
    num_classes = multi_scores.size(1) - 1
    if multi_bboxes.shape[1] > box_dim:
        # class-specific boxes: (n, #class*box_dim) -> (n, #class, box_dim)
        bboxes = multi_bboxes.view(num_boxes, -1, box_dim)
    else:
        # shared boxes: broadcast the same box to every class
        bboxes = multi_bboxes[:, None].expand(num_boxes, num_classes,
                                               box_dim)

    scores = multi_scores[:, :-1]
    labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
    labels = labels.view(1, -1).expand_as(scores)

    # flatten to one row per (box, class) candidate
    bboxes = bboxes.reshape(-1, box_dim)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        # remove low scoring boxes (threshold applies to the raw scores)
        valid_mask = scores > score_thr
    # multiply score_factor after threshold to preserve more bboxes, improve
    # mAP by 1% for YOLOv3
    if score_factors is not None:
        # expand the shape to match original shape of score
        score_factors = score_factors.view(-1, 1).expand(
            num_boxes, num_classes)
        score_factors = score_factors.reshape(-1)
        scores = scores * score_factors

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)

    if bboxes.numel() == 0:
        # NOTE(review): the export branch above always appends a dummy row,
        # so this export check is presumably unreachable; kept as a guard.
        if in_onnx_export:
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        dets = torch.cat([bboxes, scores[:, None]], -1)
        if return_inds:
            return dets, labels, inds
        else:
            return dets, labels

    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    if return_inds:
        return dets, labels[keep], inds[keep]
    else:
        return dets, labels[keep]
def fast_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    multi_coeffs: Tensor,
    score_thr: float,
    iou_thr: float,
    top_k: int,
    max_num: int = -1
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_.

    Fast NMS allows already-removed detections to suppress other detections so
    that every instance can be decided to be kept or discarded in parallel,
    which is not possible in traditional NMS. This relaxation allows us to
    implement Fast NMS entirely in standard GPU-accelerated matrix operations.

    Args:
        multi_bboxes (Tensor): shape (n, 4). Each box is shared by every
            class; class-specific (n, #class*4) boxes are not supported.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be ignored.
        multi_coeffs (Tensor): shape (n, coeffs_dim); any extra trailing
            dims are flattened. Like the boxes, each row is shared by every
            class.
        score_thr (float): bbox threshold, bboxes with scores not greater
            than it are discarded.
        iou_thr (float): IoU threshold to be considered as conflicted.
        top_k (int): if a class has more than top_k bboxes before NMS,
            only its top top_k will be kept.
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept. If -1, keep all the bboxes.
            Default: -1.

    Returns:
        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
            (dets, labels, coefficients), tensors of shape (k, 5), (k),
            and (k, coeffs_dim). Dets are boxes with scores, sorted by
            descending score. Labels are 0-based. Empty inputs (n == 0 or
            top_k == 0) yield empty outputs of these shapes.
    """
    scores = multi_scores[:, :-1].t()  # [#class, n]
    scores, idx = scores.sort(1, descending=True)

    idx = idx[:, :top_k].contiguous()
    scores = scores[:, :top_k]  # [#class, topk]
    num_classes, num_dets = idx.size()

    flat_idx = idx.view(-1)
    boxes = multi_bboxes[flat_idx, :].view(num_classes, num_dets, 4)
    # Use an explicit width: ``view(..., -1)`` cannot infer it when there
    # are zero candidates.
    flat_coeffs = multi_coeffs.flatten(1)
    coeffs = flat_coeffs[flat_idx, :].view(num_classes, num_dets,
                                           flat_coeffs.size(1))

    if num_dets > 0:
        iou = bbox_overlaps(boxes, boxes)  # [#class, topk, topk]
        # Keep only entries (i, j) with i < j; since candidates are sorted by
        # descending score, row i is a higher-ranked box than column j.
        iou.triu_(diagonal=1)
        # For each box, its largest IoU with any higher-ranked box of the
        # same class (suppressed or not -- this is the Fast NMS relaxation).
        iou_max, _ = iou.max(dim=1)
    else:
        # A max-reduction over an empty dim raises; with no candidates there
        # is nothing to suppress.
        iou_max = scores.new_zeros(num_classes, 0)

    # Now just filter out the ones higher than the threshold
    keep = iou_max <= iou_thr

    # Second thresholding introduces 0.2 mAP gain at negligible time cost
    keep &= scores > score_thr

    # Assign each kept detection to its corresponding class
    classes = torch.arange(
        num_classes, device=boxes.device)[:, None].expand_as(keep)
    classes = classes[keep]

    boxes = boxes[keep]
    coeffs = coeffs[keep]
    scores = scores[keep]

    # Only keep the top max_num highest scores across all classes
    scores, idx = scores.sort(0, descending=True)
    if max_num > 0:
        idx = idx[:max_num]
        scores = scores[:max_num]

    classes = classes[idx]
    boxes = boxes[idx]
    coeffs = coeffs[idx]

    cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
    return cls_dets, classes, coeffs
|