seesaw_loss.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from .accuracy import accuracy
  9. from .cross_entropy_loss import cross_entropy
  10. from .utils import weight_reduce_loss
  11. def seesaw_ce_loss(cls_score: Tensor,
  12. labels: Tensor,
  13. label_weights: Tensor,
  14. cum_samples: Tensor,
  15. num_classes: int,
  16. p: float,
  17. q: float,
  18. eps: float,
  19. reduction: str = 'mean',
  20. avg_factor: Optional[int] = None) -> Tensor:
  21. """Calculate the Seesaw CrossEntropy loss.
  22. Args:
  23. cls_score (Tensor): The prediction with shape (N, C),
  24. C is the number of classes.
  25. labels (Tensor): The learning label of the prediction.
  26. label_weights (Tensor): Sample-wise loss weight.
  27. cum_samples (Tensor): Cumulative samples for each category.
  28. num_classes (int): The number of classes.
  29. p (float): The ``p`` in the mitigation factor.
  30. q (float): The ``q`` in the compenstation factor.
  31. eps (float): The minimal value of divisor to smooth
  32. the computation of compensation factor
  33. reduction (str, optional): The method used to reduce the loss.
  34. avg_factor (int, optional): Average factor that is used to average
  35. the loss. Defaults to None.
  36. Returns:
  37. Tensor: The calculated loss
  38. """
  39. assert cls_score.size(-1) == num_classes
  40. assert len(cum_samples) == num_classes
  41. onehot_labels = F.one_hot(labels, num_classes)
  42. seesaw_weights = cls_score.new_ones(onehot_labels.size())
  43. # mitigation factor
  44. if p > 0:
  45. sample_ratio_matrix = cum_samples[None, :].clamp(
  46. min=1) / cum_samples[:, None].clamp(min=1)
  47. index = (sample_ratio_matrix < 1.0).float()
  48. sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
  49. mitigation_factor = sample_weights[labels.long(), :]
  50. seesaw_weights = seesaw_weights * mitigation_factor
  51. # compensation factor
  52. if q > 0:
  53. scores = F.softmax(cls_score.detach(), dim=1)
  54. self_scores = scores[
  55. torch.arange(0, len(scores)).to(scores.device).long(),
  56. labels.long()]
  57. score_matrix = scores / self_scores[:, None].clamp(min=eps)
  58. index = (score_matrix > 1.0).float()
  59. compensation_factor = score_matrix.pow(q) * index + (1 - index)
  60. seesaw_weights = seesaw_weights * compensation_factor
  61. cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))
  62. loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')
  63. if label_weights is not None:
  64. label_weights = label_weights.float()
  65. loss = weight_reduce_loss(
  66. loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
  67. return loss
  68. @MODELS.register_module()
  69. class SeesawLoss(nn.Module):
  70. """
  71. Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
  72. arXiv: https://arxiv.org/abs/2008.10032
  73. Args:
  74. use_sigmoid (bool, optional): Whether the prediction uses sigmoid
  75. of softmax. Only False is supported.
  76. p (float, optional): The ``p`` in the mitigation factor.
  77. Defaults to 0.8.
  78. q (float, optional): The ``q`` in the compenstation factor.
  79. Defaults to 2.0.
  80. num_classes (int, optional): The number of classes.
  81. Default to 1203 for LVIS v1 dataset.
  82. eps (float, optional): The minimal value of divisor to smooth
  83. the computation of compensation factor
  84. reduction (str, optional): The method that reduces the loss to a
  85. scalar. Options are "none", "mean" and "sum".
  86. loss_weight (float, optional): The weight of the loss. Defaults to 1.0
  87. return_dict (bool, optional): Whether return the losses as a dict.
  88. Default to True.
  89. """
  90. def __init__(self,
  91. use_sigmoid: bool = False,
  92. p: float = 0.8,
  93. q: float = 2.0,
  94. num_classes: int = 1203,
  95. eps: float = 1e-2,
  96. reduction: str = 'mean',
  97. loss_weight: float = 1.0,
  98. return_dict: bool = True) -> None:
  99. super().__init__()
  100. assert not use_sigmoid
  101. self.use_sigmoid = False
  102. self.p = p
  103. self.q = q
  104. self.num_classes = num_classes
  105. self.eps = eps
  106. self.reduction = reduction
  107. self.loss_weight = loss_weight
  108. self.return_dict = return_dict
  109. # 0 for pos, 1 for neg
  110. self.cls_criterion = seesaw_ce_loss
  111. # cumulative samples for each category
  112. self.register_buffer(
  113. 'cum_samples',
  114. torch.zeros(self.num_classes + 1, dtype=torch.float))
  115. # custom output channels of the classifier
  116. self.custom_cls_channels = True
  117. # custom activation of cls_score
  118. self.custom_activation = True
  119. # custom accuracy of the classsifier
  120. self.custom_accuracy = True
  121. def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
  122. """split cls_score.
  123. Args:
  124. cls_score (Tensor): The prediction with shape (N, C + 2).
  125. Returns:
  126. Tuple[Tensor, Tensor]: The score for classes and objectness,
  127. respectively
  128. """
  129. # split cls_score to cls_score_classes and cls_score_objectness
  130. assert cls_score.size(-1) == self.num_classes + 2
  131. cls_score_classes = cls_score[..., :-2]
  132. cls_score_objectness = cls_score[..., -2:]
  133. return cls_score_classes, cls_score_objectness
  134. def get_cls_channels(self, num_classes: int) -> int:
  135. """Get custom classification channels.
  136. Args:
  137. num_classes (int): The number of classes.
  138. Returns:
  139. int: The custom classification channels.
  140. """
  141. assert num_classes == self.num_classes
  142. return num_classes + 2
  143. def get_activation(self, cls_score: Tensor) -> Tensor:
  144. """Get custom activation of cls_score.
  145. Args:
  146. cls_score (Tensor): The prediction with shape (N, C + 2).
  147. Returns:
  148. Tensor: The custom activation of cls_score with shape
  149. (N, C + 1).
  150. """
  151. cls_score_classes, cls_score_objectness = self._split_cls_score(
  152. cls_score)
  153. score_classes = F.softmax(cls_score_classes, dim=-1)
  154. score_objectness = F.softmax(cls_score_objectness, dim=-1)
  155. score_pos = score_objectness[..., [0]]
  156. score_neg = score_objectness[..., [1]]
  157. score_classes = score_classes * score_pos
  158. scores = torch.cat([score_classes, score_neg], dim=-1)
  159. return scores
  160. def get_accuracy(self, cls_score: Tensor,
  161. labels: Tensor) -> Dict[str, Tensor]:
  162. """Get custom accuracy w.r.t. cls_score and labels.
  163. Args:
  164. cls_score (Tensor): The prediction with shape (N, C + 2).
  165. labels (Tensor): The learning label of the prediction.
  166. Returns:
  167. Dict [str, Tensor]: The accuracy for objectness and classes,
  168. respectively.
  169. """
  170. pos_inds = labels < self.num_classes
  171. obj_labels = (labels == self.num_classes).long()
  172. cls_score_classes, cls_score_objectness = self._split_cls_score(
  173. cls_score)
  174. acc_objectness = accuracy(cls_score_objectness, obj_labels)
  175. acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
  176. acc = dict()
  177. acc['acc_objectness'] = acc_objectness
  178. acc['acc_classes'] = acc_classes
  179. return acc
  180. def forward(
  181. self,
  182. cls_score: Tensor,
  183. labels: Tensor,
  184. label_weights: Optional[Tensor] = None,
  185. avg_factor: Optional[int] = None,
  186. reduction_override: Optional[str] = None
  187. ) -> Union[Tensor, Dict[str, Tensor]]:
  188. """Forward function.
  189. Args:
  190. cls_score (Tensor): The prediction with shape (N, C + 2).
  191. labels (Tensor): The learning label of the prediction.
  192. label_weights (Tensor, optional): Sample-wise loss weight.
  193. avg_factor (int, optional): Average factor that is used to average
  194. the loss. Defaults to None.
  195. reduction (str, optional): The method used to reduce the loss.
  196. Options are "none", "mean" and "sum".
  197. Returns:
  198. Tensor | Dict [str, Tensor]:
  199. if return_dict == False: The calculated loss |
  200. if return_dict == True: The dict of calculated losses
  201. for objectness and classes, respectively.
  202. """
  203. assert reduction_override in (None, 'none', 'mean', 'sum')
  204. reduction = (
  205. reduction_override if reduction_override else self.reduction)
  206. assert cls_score.size(-1) == self.num_classes + 2
  207. pos_inds = labels < self.num_classes
  208. # 0 for pos, 1 for neg
  209. obj_labels = (labels == self.num_classes).long()
  210. # accumulate the samples for each category
  211. unique_labels = labels.unique()
  212. for u_l in unique_labels:
  213. inds_ = labels == u_l.item()
  214. self.cum_samples[u_l] += inds_.sum()
  215. if label_weights is not None:
  216. label_weights = label_weights.float()
  217. else:
  218. label_weights = labels.new_ones(labels.size(), dtype=torch.float)
  219. cls_score_classes, cls_score_objectness = self._split_cls_score(
  220. cls_score)
  221. # calculate loss_cls_classes (only need pos samples)
  222. if pos_inds.sum() > 0:
  223. loss_cls_classes = self.loss_weight * self.cls_criterion(
  224. cls_score_classes[pos_inds], labels[pos_inds],
  225. label_weights[pos_inds], self.cum_samples[:self.num_classes],
  226. self.num_classes, self.p, self.q, self.eps, reduction,
  227. avg_factor)
  228. else:
  229. loss_cls_classes = cls_score_classes[pos_inds].sum()
  230. # calculate loss_cls_objectness
  231. loss_cls_objectness = self.loss_weight * cross_entropy(
  232. cls_score_objectness, obj_labels, label_weights, reduction,
  233. avg_factor)
  234. if self.return_dict:
  235. loss_cls = dict()
  236. loss_cls['loss_cls_objectness'] = loss_cls_objectness
  237. loss_cls['loss_cls_classes'] = loss_cls_classes
  238. else:
  239. loss_cls = loss_cls_classes + loss_cls_objectness
  240. return loss_cls