123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from mmdet.registry import MODELS
- from .utils import weight_reduce_loss
def varifocal_loss(pred: Tensor,
                   target: Tensor,
                   weight: Optional[Tensor] = None,
                   alpha: float = 0.75,
                   gamma: float = 2.0,
                   iou_weighted: bool = True,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Compute `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    The loss is an element-wise sigmoid binary cross entropy between
    ``pred`` (logits) and ``target``, scaled by a per-element weight:

    - where ``target > 0`` the weight is ``target`` itself when
      ``iou_weighted`` is True, otherwise ``1``;
    - where ``target <= 0`` the weight is
      ``alpha * |sigmoid(pred) - target| ** gamma``.

    Args:
        pred (Tensor): The prediction (logits) with shape (N, C), C is the
            number of classes.
        target (Tensor): The learning target of the iou-aware
            classification score with shape (N, C). It must have exactly
            the same size as ``pred`` and is cast to the type of ``pred``.
        weight (Tensor, optional): The weight of loss for each
            prediction, forwarded to ``weight_reduce_loss``.
            Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal
            Loss. Defaults to 0.75.
        gamma (float, optional): The exponent of the modulating factor
            applied to the negative part. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive examples with the iou target. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar, forwarded to ``weight_reduce_loss``. Options are
            "none", "mean" and "sum". Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss, forwarded to ``weight_reduce_loss``.
            Defaults to None.

    Returns:
        Tensor: Loss tensor.
    """
    # pred and target must agree element for element.
    assert pred.size() == target.size()

    target = target.type_as(pred)
    prob = pred.sigmoid()

    # Indicator masks splitting the elements into positives and negatives.
    pos_mask = (target > 0.0).float()
    neg_mask = (target <= 0.0).float()

    # Positives: scaled by the target itself or left at unit weight.
    pos_weight = target * pos_mask if iou_weighted else pos_mask
    # Negatives: down-weighted by how far the score is from the target.
    neg_weight = alpha * (prob - target).abs().pow(gamma) * neg_mask
    focal_weight = pos_weight + neg_weight

    bce = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')
    weighted_bce = bce * focal_weight
    return weight_reduce_loss(weighted_bce, weight, reduction, avg_factor)
@MODELS.register_module()
class VarifocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid: bool = True,
                 alpha: float = 0.75,
                 gamma: float = 2.0,
                 iou_weighted: bool = True,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        """Module wrapper of
        `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                used for sigmoid or softmax. Only ``True`` is accepted.
                Defaults to True.
            alpha (float, optional): A balance factor for the negative
                part of Varifocal Loss, which is different from the alpha
                of Focal Loss. Must be non-negative. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the
                modulating factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of
                the positive examples with the iou target.
                Defaults to True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            loss_weight (float, optional): Weight of loss.
                Defaults to 1.0.
        """
        super().__init__()

        # Validate the configuration before storing anything.
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0

        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.iou_weighted = iou_weighted
        self.alpha = alpha
        self.gamma = gamma

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the weighted varifocal loss.

        Args:
            pred (Tensor): The prediction with shape (N, C), C is the
                number of classes.
            target (Tensor): The learning target of the iou-aware
                classification score with shape (N, C), C is
                the number of classes.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the reduction configured on this module.
                Options are "none", "mean" and "sum". When None, the
                configured reduction is used. Defaults to None.

        Returns:
            Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')

        # Only the sigmoid formulation is implemented.
        if not self.use_sigmoid:
            raise NotImplementedError

        # A per-call override takes precedence over the configured value.
        reduction = reduction_override or self.reduction
        return self.loss_weight * varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)
|