  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmdet.registry import MODELS
  7. from .utils import weight_reduce_loss
  8. def cross_entropy(pred,
  9. label,
  10. weight=None,
  11. reduction='mean',
  12. avg_factor=None,
  13. class_weight=None,
  14. ignore_index=-100,
  15. avg_non_ignore=False):
  16. """Calculate the CrossEntropy loss.
  17. Args:
  18. pred (torch.Tensor): The prediction with shape (N, C), C is the number
  19. of classes.
  20. label (torch.Tensor): The learning label of the prediction.
  21. weight (torch.Tensor, optional): Sample-wise loss weight.
  22. reduction (str, optional): The method used to reduce the loss.
  23. avg_factor (int, optional): Average factor that is used to average
  24. the loss. Defaults to None.
  25. class_weight (list[float], optional): The weight for each class.
  26. ignore_index (int | None): The label index to be ignored.
  27. If None, it will be set to default value. Default: -100.
  28. avg_non_ignore (bool): The flag decides to whether the loss is
  29. only averaged over non-ignored targets. Default: False.
  30. Returns:
  31. torch.Tensor: The calculated loss
  32. """
  33. # The default value of ignore_index is the same as F.cross_entropy
  34. ignore_index = -100 if ignore_index is None else ignore_index
  35. # element-wise losses
  36. loss = F.cross_entropy(
  37. pred,
  38. label,
  39. weight=class_weight,
  40. reduction='none',
  41. ignore_index=ignore_index)
  42. # average loss over non-ignored elements
  43. # pytorch's official cross_entropy average loss over non-ignored elements
  44. # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
  45. if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
  46. avg_factor = label.numel() - (label == ignore_index).sum().item()
  47. # apply weights and do the reduction
  48. if weight is not None:
  49. weight = weight.float()
  50. loss = weight_reduce_loss(
  51. loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
  52. return loss
  53. def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
  54. """Expand onehot labels to match the size of prediction."""
  55. bin_labels = labels.new_full((labels.size(0), label_channels), 0)
  56. valid_mask = (labels >= 0) & (labels != ignore_index)
  57. inds = torch.nonzero(
  58. valid_mask & (labels < label_channels), as_tuple=False)
  59. if inds.numel() > 0:
  60. bin_labels[inds, labels[inds]] = 1
  61. valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
  62. label_channels).float()
  63. if label_weights is None:
  64. bin_label_weights = valid_mask
  65. else:
  66. bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
  67. bin_label_weights *= valid_mask
  68. return bin_labels, bin_label_weights, valid_mask
  69. def binary_cross_entropy(pred,
  70. label,
  71. weight=None,
  72. reduction='mean',
  73. avg_factor=None,
  74. class_weight=None,
  75. ignore_index=-100,
  76. avg_non_ignore=False):
  77. """Calculate the binary CrossEntropy loss.
  78. Args:
  79. pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
  80. When the shape of pred is (N, 1), label will be expanded to
  81. one-hot format, and when the shape of pred is (N, ), label
  82. will not be expanded to one-hot format.
  83. label (torch.Tensor): The learning label of the prediction,
  84. with shape (N, ).
  85. weight (torch.Tensor, optional): Sample-wise loss weight.
  86. reduction (str, optional): The method used to reduce the loss.
  87. Options are "none", "mean" and "sum".
  88. avg_factor (int, optional): Average factor that is used to average
  89. the loss. Defaults to None.
  90. class_weight (list[float], optional): The weight for each class.
  91. ignore_index (int | None): The label index to be ignored.
  92. If None, it will be set to default value. Default: -100.
  93. avg_non_ignore (bool): The flag decides to whether the loss is
  94. only averaged over non-ignored targets. Default: False.
  95. Returns:
  96. torch.Tensor: The calculated loss.
  97. """
  98. # The default value of ignore_index is the same as F.cross_entropy
  99. ignore_index = -100 if ignore_index is None else ignore_index
  100. if pred.dim() != label.dim():
  101. label, weight, valid_mask = _expand_onehot_labels(
  102. label, weight, pred.size(-1), ignore_index)
  103. else:
  104. # should mask out the ignored elements
  105. valid_mask = ((label >= 0) & (label != ignore_index)).float()
  106. if weight is not None:
  107. # The inplace writing method will have a mismatched broadcast
  108. # shape error if the weight and valid_mask dimensions
  109. # are inconsistent such as (B,N,1) and (B,N,C).
  110. weight = weight * valid_mask
  111. else:
  112. weight = valid_mask
  113. # average loss over non-ignored elements
  114. if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
  115. avg_factor = valid_mask.sum().item()
  116. # weighted element-wise losses
  117. weight = weight.float()
  118. loss = F.binary_cross_entropy_with_logits(
  119. pred, label.float(), pos_weight=class_weight, reduction='none')
  120. # do the reduction for the weighted loss
  121. loss = weight_reduce_loss(
  122. loss, weight, reduction=reduction, avg_factor=avg_factor)
  123. return loss
  124. def mask_cross_entropy(pred,
  125. target,
  126. label,
  127. reduction='mean',
  128. avg_factor=None,
  129. class_weight=None,
  130. ignore_index=None,
  131. **kwargs):
  132. """Calculate the CrossEntropy loss for masks.
  133. Args:
  134. pred (torch.Tensor): The prediction with shape (N, C, *), C is the
  135. number of classes. The trailing * indicates arbitrary shape.
  136. target (torch.Tensor): The learning label of the prediction.
  137. label (torch.Tensor): ``label`` indicates the class label of the mask
  138. corresponding object. This will be used to select the mask in the
  139. of the class which the object belongs to when the mask prediction
  140. if not class-agnostic.
  141. reduction (str, optional): The method used to reduce the loss.
  142. Options are "none", "mean" and "sum".
  143. avg_factor (int, optional): Average factor that is used to average
  144. the loss. Defaults to None.
  145. class_weight (list[float], optional): The weight for each class.
  146. ignore_index (None): Placeholder, to be consistent with other loss.
  147. Default: None.
  148. Returns:
  149. torch.Tensor: The calculated loss
  150. Example:
  151. >>> N, C = 3, 11
  152. >>> H, W = 2, 2
  153. >>> pred = torch.randn(N, C, H, W) * 1000
  154. >>> target = torch.rand(N, H, W)
  155. >>> label = torch.randint(0, C, size=(N,))
  156. >>> reduction = 'mean'
  157. >>> avg_factor = None
  158. >>> class_weights = None
  159. >>> loss = mask_cross_entropy(pred, target, label, reduction,
  160. >>> avg_factor, class_weights)
  161. >>> assert loss.shape == (1,)
  162. """
  163. assert ignore_index is None, 'BCE loss does not support ignore_index'
  164. # TODO: handle these two reserved arguments
  165. assert reduction == 'mean' and avg_factor is None
  166. num_rois = pred.size()[0]
  167. inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
  168. pred_slice = pred[inds, label].squeeze(1)
  169. return F.binary_cross_entropy_with_logits(
  170. pred_slice, target, weight=class_weight, reduction='mean')[None]
  171. @MODELS.register_module()
  172. class CrossEntropyLoss(nn.Module):
  173. def __init__(self,
  174. use_sigmoid=False,
  175. use_mask=False,
  176. reduction='mean',
  177. class_weight=None,
  178. ignore_index=None,
  179. loss_weight=1.0,
  180. avg_non_ignore=False):
  181. """CrossEntropyLoss.
  182. Args:
  183. use_sigmoid (bool, optional): Whether the prediction uses sigmoid
  184. of softmax. Defaults to False.
  185. use_mask (bool, optional): Whether to use mask cross entropy loss.
  186. Defaults to False.
  187. reduction (str, optional): . Defaults to 'mean'.
  188. Options are "none", "mean" and "sum".
  189. class_weight (list[float], optional): Weight of each class.
  190. Defaults to None.
  191. ignore_index (int | None): The label index to be ignored.
  192. Defaults to None.
  193. loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
  194. avg_non_ignore (bool): The flag decides to whether the loss is
  195. only averaged over non-ignored targets. Default: False.
  196. """
  197. super(CrossEntropyLoss, self).__init__()
  198. assert (use_sigmoid is False) or (use_mask is False)
  199. self.use_sigmoid = use_sigmoid
  200. self.use_mask = use_mask
  201. self.reduction = reduction
  202. self.loss_weight = loss_weight
  203. self.class_weight = class_weight
  204. self.ignore_index = ignore_index
  205. self.avg_non_ignore = avg_non_ignore
  206. if ((ignore_index is not None) and not self.avg_non_ignore
  207. and self.reduction == 'mean'):
  208. warnings.warn(
  209. 'Default ``avg_non_ignore`` is False, if you would like to '
  210. 'ignore the certain label and average loss over non-ignore '
  211. 'labels, which is the same with PyTorch official '
  212. 'cross_entropy, set ``avg_non_ignore=True``.')
  213. if self.use_sigmoid:
  214. self.cls_criterion = binary_cross_entropy
  215. elif self.use_mask:
  216. self.cls_criterion = mask_cross_entropy
  217. else:
  218. self.cls_criterion = cross_entropy
  219. def extra_repr(self):
  220. """Extra repr."""
  221. s = f'avg_non_ignore={self.avg_non_ignore}'
  222. return s
  223. def forward(self,
  224. cls_score,
  225. label,
  226. weight=None,
  227. avg_factor=None,
  228. reduction_override=None,
  229. ignore_index=None,
  230. **kwargs):
  231. """Forward function.
  232. Args:
  233. cls_score (torch.Tensor): The prediction.
  234. label (torch.Tensor): The learning label of the prediction.
  235. weight (torch.Tensor, optional): Sample-wise loss weight.
  236. avg_factor (int, optional): Average factor that is used to average
  237. the loss. Defaults to None.
  238. reduction_override (str, optional): The method used to reduce the
  239. loss. Options are "none", "mean" and "sum".
  240. ignore_index (int | None): The label index to be ignored.
  241. If not None, it will override the default value. Default: None.
  242. Returns:
  243. torch.Tensor: The calculated loss.
  244. """
  245. assert reduction_override in (None, 'none', 'mean', 'sum')
  246. reduction = (
  247. reduction_override if reduction_override else self.reduction)
  248. if ignore_index is None:
  249. ignore_index = self.ignore_index
  250. if self.class_weight is not None:
  251. class_weight = cls_score.new_tensor(
  252. self.class_weight, device=cls_score.device)
  253. else:
  254. class_weight = None
  255. loss_cls = self.loss_weight * self.cls_criterion(
  256. cls_score,
  257. label,
  258. weight,
  259. class_weight=class_weight,
  260. reduction=reduction,
  261. avg_factor=avg_factor,
  262. ignore_index=ignore_index,
  263. avg_non_ignore=self.avg_non_ignore,
  264. **kwargs)
  265. return loss_cls