utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
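

# A minimal usage sketch for ``reduce_loss`` (an illustrative addition, not
# part of the original module; the tensor values are made up):
#
#   >>> loss = torch.tensor([1.0, 2.0, 3.0])
#   >>> reduce_loss(loss, 'none')
#   tensor([1., 2., 3.])
#   >>> reduce_loss(loss, 'mean')
#   tensor(2.)
#   >>> reduce_loss(loss, 'sum')
#   tensor(6.)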


def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. Defaults to None.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is 'mean', average the loss by avg_factor
        if reduction == 'mean':
            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
            # i.e., all labels of an image belong to the ignore index.
            eps = torch.finfo(torch.float32).eps
            loss = loss.sum() / (avg_factor + eps)
        # if reduction is 'none', then do nothing; otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor cannot be used with reduction="sum"')
    return loss
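

# An illustrative sketch of ``weight_reduce_loss`` (not part of the original
# module; the values are assumed): the weight zeroes out the second element,
# then the weighted loss is summed and divided by the explicit avg_factor.
#
#   >>> loss = torch.tensor([1.0, 1.0, 2.0])
#   >>> weight = torch.tensor([1.0, 0.0, 1.0])
#   >>> weight_reduce_loss(loss, weight, reduction='mean', avg_factor=2)
#   tensor(1.5000)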


def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[float] = None,
                **kwargs) -> Tensor:
        """
        Args:
            pred (Tensor): The prediction.
            target (Tensor): Target bboxes.
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            reduction (str, optional): Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            avg_factor (Optional[float], optional): Average factor that is
                used to average the loss. Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        # apply element-wise weighting and the requested reduction
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
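

# A hedged sketch of defining a new loss with the decorator (``mse_loss`` is
# a hypothetical example, not part of this module): the decorated function
# only computes the element-wise loss; weighting and reduction come for free.
#
#   >>> @weighted_loss
#   >>> def mse_loss(pred, target):
#   >>>     return (pred - target) ** 2
#   >>> pred = torch.tensor([0.0, 2.0, 3.0])
#   >>> target = torch.tensor([1.0, 1.0, 1.0])
#   >>> mse_loss(pred, target, reduction='sum')
#   tensor(6.)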