# focal_loss.py (11 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
  6. from mmdet.registry import MODELS
  7. from .utils import weight_reduce_loss
  8. # This method is only for debugging
  9. def py_sigmoid_focal_loss(pred,
  10. target,
  11. weight=None,
  12. gamma=2.0,
  13. alpha=0.25,
  14. reduction='mean',
  15. avg_factor=None):
  16. """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
  17. Args:
  18. pred (torch.Tensor): The prediction with shape (N, C), C is the
  19. number of classes
  20. target (torch.Tensor): The learning label of the prediction.
  21. weight (torch.Tensor, optional): Sample-wise loss weight.
  22. gamma (float, optional): The gamma for calculating the modulating
  23. factor. Defaults to 2.0.
  24. alpha (float, optional): A balanced form for Focal Loss.
  25. Defaults to 0.25.
  26. reduction (str, optional): The method used to reduce the loss into
  27. a scalar. Defaults to 'mean'.
  28. avg_factor (int, optional): Average factor that is used to average
  29. the loss. Defaults to None.
  30. """
  31. pred_sigmoid = pred.sigmoid()
  32. target = target.type_as(pred)
  33. pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
  34. focal_weight = (alpha * target + (1 - alpha) *
  35. (1 - target)) * pt.pow(gamma)
  36. loss = F.binary_cross_entropy_with_logits(
  37. pred, target, reduction='none') * focal_weight
  38. if weight is not None:
  39. if weight.shape != loss.shape:
  40. if weight.size(0) == loss.size(0):
  41. # For most cases, weight is of shape (num_priors, ),
  42. # which means it does not have the second axis num_class
  43. weight = weight.view(-1, 1)
  44. else:
  45. # Sometimes, weight per anchor per class is also needed. e.g.
  46. # in FSAF. But it may be flattened of shape
  47. # (num_priors x num_class, ), while loss is still of shape
  48. # (num_priors, num_class).
  49. assert weight.numel() == loss.numel()
  50. weight = weight.view(loss.size(0), -1)
  51. assert weight.ndim == loss.ndim
  52. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  53. return loss
  54. def py_focal_loss_with_prob(pred,
  55. target,
  56. weight=None,
  57. gamma=2.0,
  58. alpha=0.25,
  59. reduction='mean',
  60. avg_factor=None):
  61. """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
  62. Different from `py_sigmoid_focal_loss`, this function accepts probability
  63. as input.
  64. Args:
  65. pred (torch.Tensor): The prediction probability with shape (N, C),
  66. C is the number of classes.
  67. target (torch.Tensor): The learning label of the prediction.
  68. The target shape support (N,C) or (N,), (N,C) means one-hot form.
  69. weight (torch.Tensor, optional): Sample-wise loss weight.
  70. gamma (float, optional): The gamma for calculating the modulating
  71. factor. Defaults to 2.0.
  72. alpha (float, optional): A balanced form for Focal Loss.
  73. Defaults to 0.25.
  74. reduction (str, optional): The method used to reduce the loss into
  75. a scalar. Defaults to 'mean'.
  76. avg_factor (int, optional): Average factor that is used to average
  77. the loss. Defaults to None.
  78. """
  79. if pred.dim() != target.dim():
  80. num_classes = pred.size(1)
  81. target = F.one_hot(target, num_classes=num_classes + 1)
  82. target = target[:, :num_classes]
  83. target = target.type_as(pred)
  84. pt = (1 - pred) * target + pred * (1 - target)
  85. focal_weight = (alpha * target + (1 - alpha) *
  86. (1 - target)) * pt.pow(gamma)
  87. loss = F.binary_cross_entropy(
  88. pred, target, reduction='none') * focal_weight
  89. if weight is not None:
  90. if weight.shape != loss.shape:
  91. if weight.size(0) == loss.size(0):
  92. # For most cases, weight is of shape (num_priors, ),
  93. # which means it does not have the second axis num_class
  94. weight = weight.view(-1, 1)
  95. else:
  96. # Sometimes, weight per anchor per class is also needed. e.g.
  97. # in FSAF. But it may be flattened of shape
  98. # (num_priors x num_class, ), while loss is still of shape
  99. # (num_priors, num_class).
  100. assert weight.numel() == loss.numel()
  101. weight = weight.view(loss.size(0), -1)
  102. assert weight.ndim == loss.ndim
  103. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  104. return loss
  105. def sigmoid_focal_loss(pred,
  106. target,
  107. weight=None,
  108. gamma=2.0,
  109. alpha=0.25,
  110. reduction='mean',
  111. avg_factor=None):
  112. r"""A wrapper of cuda version `Focal Loss
  113. <https://arxiv.org/abs/1708.02002>`_.
  114. Args:
  115. pred (torch.Tensor): The prediction with shape (N, C), C is the number
  116. of classes.
  117. target (torch.Tensor): The learning label of the prediction.
  118. weight (torch.Tensor, optional): Sample-wise loss weight.
  119. gamma (float, optional): The gamma for calculating the modulating
  120. factor. Defaults to 2.0.
  121. alpha (float, optional): A balanced form for Focal Loss.
  122. Defaults to 0.25.
  123. reduction (str, optional): The method used to reduce the loss into
  124. a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
  125. avg_factor (int, optional): Average factor that is used to average
  126. the loss. Defaults to None.
  127. """
  128. # Function.apply does not accept keyword arguments, so the decorator
  129. # "weighted_loss" is not applicable
  130. loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
  131. alpha, None, 'none')
  132. if weight is not None:
  133. if weight.shape != loss.shape:
  134. if weight.size(0) == loss.size(0):
  135. # For most cases, weight is of shape (num_priors, ),
  136. # which means it does not have the second axis num_class
  137. weight = weight.view(-1, 1)
  138. else:
  139. # Sometimes, weight per anchor per class is also needed. e.g.
  140. # in FSAF. But it may be flattened of shape
  141. # (num_priors x num_class, ), while loss is still of shape
  142. # (num_priors, num_class).
  143. assert weight.numel() == loss.numel()
  144. weight = weight.view(loss.size(0), -1)
  145. assert weight.ndim == loss.ndim
  146. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  147. return loss
  148. @MODELS.register_module()
  149. class FocalLoss(nn.Module):
  150. def __init__(self,
  151. use_sigmoid=True,
  152. gamma=2.0,
  153. alpha=0.25,
  154. reduction='mean',
  155. loss_weight=1.0,
  156. activated=False):
  157. """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
  158. Args:
  159. use_sigmoid (bool, optional): Whether to the prediction is
  160. used for sigmoid or softmax. Defaults to True.
  161. gamma (float, optional): The gamma for calculating the modulating
  162. factor. Defaults to 2.0.
  163. alpha (float, optional): A balanced form for Focal Loss.
  164. Defaults to 0.25.
  165. reduction (str, optional): The method used to reduce the loss into
  166. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  167. "sum".
  168. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  169. activated (bool, optional): Whether the input is activated.
  170. If True, it means the input has been activated and can be
  171. treated as probabilities. Else, it should be treated as logits.
  172. Defaults to False.
  173. """
  174. super(FocalLoss, self).__init__()
  175. assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
  176. self.use_sigmoid = use_sigmoid
  177. self.gamma = gamma
  178. self.alpha = alpha
  179. self.reduction = reduction
  180. self.loss_weight = loss_weight
  181. self.activated = activated
  182. def forward(self,
  183. pred,
  184. target,
  185. weight=None,
  186. avg_factor=None,
  187. reduction_override=None):
  188. """Forward function.
  189. Args:
  190. pred (torch.Tensor): The prediction.
  191. target (torch.Tensor): The learning label of the prediction.
  192. The target shape support (N,C) or (N,), (N,C) means
  193. one-hot form.
  194. weight (torch.Tensor, optional): The weight of loss for each
  195. prediction. Defaults to None.
  196. avg_factor (int, optional): Average factor that is used to average
  197. the loss. Defaults to None.
  198. reduction_override (str, optional): The reduction method used to
  199. override the original reduction method of the loss.
  200. Options are "none", "mean" and "sum".
  201. Returns:
  202. torch.Tensor: The calculated loss
  203. """
  204. assert reduction_override in (None, 'none', 'mean', 'sum')
  205. reduction = (
  206. reduction_override if reduction_override else self.reduction)
  207. if self.use_sigmoid:
  208. if self.activated:
  209. calculate_loss_func = py_focal_loss_with_prob
  210. else:
  211. if pred.dim() == target.dim():
  212. # this means that target is already in One-Hot form.
  213. calculate_loss_func = py_sigmoid_focal_loss
  214. elif torch.cuda.is_available() and pred.is_cuda:
  215. calculate_loss_func = sigmoid_focal_loss
  216. else:
  217. num_classes = pred.size(1)
  218. target = F.one_hot(target, num_classes=num_classes + 1)
  219. target = target[:, :num_classes]
  220. calculate_loss_func = py_sigmoid_focal_loss
  221. loss_cls = self.loss_weight * calculate_loss_func(
  222. pred,
  223. target,
  224. weight,
  225. gamma=self.gamma,
  226. alpha=self.alpha,
  227. reduction=reduction,
  228. avg_factor=avg_factor)
  229. else:
  230. raise NotImplementedError
  231. return loss_cls