mse_loss.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from .utils import weighted_loss
  8. @weighted_loss
  9. def mse_loss(pred: Tensor, target: Tensor) -> Tensor:
  10. """A Wrapper of MSE loss.
  11. Args:
  12. pred (Tensor): The prediction.
  13. target (Tensor): The learning target of the prediction.
  14. Returns:
  15. Tensor: loss Tensor
  16. """
  17. return F.mse_loss(pred, target, reduction='none')
  18. @MODELS.register_module()
  19. class MSELoss(nn.Module):
  20. """MSELoss.
  21. Args:
  22. reduction (str, optional): The method that reduces the loss to a
  23. scalar. Options are "none", "mean" and "sum".
  24. loss_weight (float, optional): The weight of the loss. Defaults to 1.0
  25. """
  26. def __init__(self,
  27. reduction: str = 'mean',
  28. loss_weight: float = 1.0) -> None:
  29. super().__init__()
  30. self.reduction = reduction
  31. self.loss_weight = loss_weight
  32. def forward(self,
  33. pred: Tensor,
  34. target: Tensor,
  35. weight: Optional[Tensor] = None,
  36. avg_factor: Optional[int] = None,
  37. reduction_override: Optional[str] = None) -> Tensor:
  38. """Forward function of loss.
  39. Args:
  40. pred (Tensor): The prediction.
  41. target (Tensor): The learning target of the prediction.
  42. weight (Tensor, optional): Weight of the loss for each
  43. prediction. Defaults to None.
  44. avg_factor (int, optional): Average factor that is used to average
  45. the loss. Defaults to None.
  46. reduction_override (str, optional): The reduction method used to
  47. override the original reduction method of the loss.
  48. Defaults to None.
  49. Returns:
  50. Tensor: The calculated loss.
  51. """
  52. assert reduction_override in (None, 'none', 'mean', 'sum')
  53. reduction = (
  54. reduction_override if reduction_override else self.reduction)
  55. loss = self.loss_weight * mse_loss(
  56. pred, target, weight, reduction=reduction, avg_factor=avg_factor)
  57. return loss