# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
    """Smooth L1 loss.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.

    Returns:
        Tensor: Calculated loss
    """
    assert beta > 0
    if target.numel() == 0:
        # Keep the graph alive on empty targets instead of returning NaN.
        return pred.sum() * 0

    assert pred.size() == target.size()
    diff = torch.abs(pred - target)
    # Quadratic below `beta`, linear above it (Huber-style transition).
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return loss
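

# Illustrative sketch (not part of the upstream file): with beta=1.0 the
# element-wise loss is 0.5 * diff**2 for |diff| < 1 and |diff| - 0.5
# otherwise, so a residual of 2.0 yields 1.5 and a residual of 0.5 yields
# 0.125. A quick check, assuming the @weighted_loss wrapper exposes a
# `reduction` keyword as it does in MMDetection:
#
# >>> pred = torch.tensor([0.0, 0.0])
# >>> target = torch.tensor([2.0, 0.5])
# >>> smooth_l1_loss(pred, target, beta=1.0, reduction='none')
# tensor([1.5000, 0.1250])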


@weighted_loss
def l1_loss(pred: Tensor, target: Tensor) -> Tensor:
    """L1 loss.

    Args:
        pred (Tensor): The prediction.
        target (Tensor): The learning target of the prediction.

    Returns:
        Tensor: Calculated loss
    """
    if target.numel() == 0:
        # Keep the graph alive on empty targets instead of returning NaN.
        return pred.sum() * 0

    assert pred.size() == target.size()
    loss = torch.abs(pred - target)
    return loss
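

# Illustrative note (assumption, not from the upstream file): l1_loss is the
# beta -> 0 limit of smooth_l1_loss, i.e. plain |pred - target| per element
# before the @weighted_loss reduction is applied.
#
# >>> l1_loss(torch.tensor([1.0, -2.0]), torch.tensor([0.0, 0.0]),
# ...         reduction='none')
# tensor([1., 2.])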


@MODELS.register_module()
class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.

    Args:
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 beta: float = 1.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * smooth_l1_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox
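

# Usage sketch (assumption; mirrors how MMDetection bbox heads typically call
# a registered loss module, not part of the upstream file):
#
# >>> loss_fn = SmoothL1Loss(beta=1.0, reduction='mean', loss_weight=1.0)
# >>> pred = torch.rand(4, 4)
# >>> target = torch.rand(4, 4)
# >>> weight = torch.ones(4, 4)  # per-element weights, e.g. to mask samples
# >>> loss_fn(pred, target, weight, avg_factor=4)  # scalar Tensor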


@MODELS.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
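

# Config-style usage sketch (assumption; registered losses in MMDetection are
# usually built from a config dict via the MODELS registry rather than
# instantiated directly):
#
# loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
# or, for a plain L1 penalty:
# loss_bbox=dict(type='L1Loss', loss_weight=1.0)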