123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import functools
- from typing import Callable, Optional
- import torch
- import torch.nn.functional as F
- from torch import Tensor
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss values.
        reduction (str): Reduction name, resolved through PyTorch's
            ``F._Reduction.get_enum``. Options are "none", "mean" and "sum".

    Return:
        Tensor: ``loss`` itself for "none", otherwise its mean or its sum.
    """
    # Enum values produced by get_enum: none -> 0, (elementwise_)mean -> 1,
    # sum -> 2.
    mode = F._Reduction.get_enum(reduction)
    reducers = {
        0: lambda values: values,
        1: lambda values: values.mean(),
        2: lambda values: values.sum(),
    }
    reducer = reducers.get(mode)
    if reducer is None:
        # Mirrors the original fall-through for an unrecognised enum value.
        return None
    return reducer(loss)
def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights, multiplied
            into ``loss`` before any reduction. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch
            ("none", "mean" or "sum"). Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. When given with "mean", the result
            is ``loss.sum() / avg_factor`` rather than ``loss.mean()``; when
            given with "none", the weighted loss is returned unreduced.
            Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than "mean" or "none".
    """
    # If weight is specified, apply element-wise weight.
    if weight is not None:
        loss = loss * weight

    # Without avg_factor, defer to the standard reduction.
    if avg_factor is None:
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # Avoid causing ZeroDivisionError when avg_factor is 0.0,
        # i.e., all labels of an image belong to ignore index.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction == 'none':
        return loss
    # Report the reduction actually passed; the message previously always
    # said "sum" even for other (e.g. invalid) values.
    raise ValueError(
        f'avg_factor can not be used with reduction="{reduction}"')
def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()
    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])
    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    # NOTE: functools.wraps copies loss_func's __name__, __doc__ and
    # __annotations__ onto the wrapper, so the docstring and annotations
    # below document the wrapper in source only.
    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[float] = None,
                **kwargs) -> Tensor:
        """
        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            reduction (str, optional): Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            avg_factor (Optional[float], optional): Average factor that is
                used to average the loss. Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        # Get the element-wise (unreduced) loss; extra kwargs go to loss_func.
        loss = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(loss, weight, reduction, avg_factor)

    return wrapper
|