123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from mmdet.registry import MODELS
- from .accuracy import accuracy
- from .cross_entropy_loss import cross_entropy
- from .utils import weight_reduce_loss
def seesaw_ce_loss(cls_score: Tensor,
                   labels: Tensor,
                   label_weights: Tensor,
                   cum_samples: Tensor,
                   num_classes: int,
                   p: float,
                   q: float,
                   eps: float,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (Tensor): The learning label of the prediction, one class
            index per row of ``cls_score``. Any integer dtype is accepted;
            it is cast to int64 internally.
        label_weights (Tensor | None): Sample-wise loss weight. ``None`` is
            passed through to ``weight_reduce_loss`` unchanged.
        cum_samples (Tensor): Cumulative samples for each category,
            a 1-D tensor of length ``num_classes``.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor. The factor is
            skipped when ``p <= 0``.
        q (float): The ``q`` in the compensation factor. The factor is
            skipped when ``q <= 0``.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor.
        reduction (str, optional): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        Tensor: The calculated loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    # F.one_hot and F.cross_entropy both require int64 indices, so cast once
    # up front instead of at each use site (one_hot would otherwise reject
    # e.g. int32 labels before any later .long() is reached).
    labels = labels.long()
    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor: for a sample whose label is class i, the negative
    # logit of every class j with fewer accumulated samples (N_j < N_i) is
    # scaled by (N_j / N_i) ** p; other classes keep weight 1.
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[labels, :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor: a class j whose predicted probability exceeds the
    # ground-truth class probability is scaled by (s_j / s_gt) ** q.
    # Scores are detached so this factor carries no gradient.
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(len(scores), device=scores.device), labels]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # Adding log(w) to a logit multiplies its exp() term in the softmax
    # denominator by w; the ground-truth column (onehot == 1) is unchanged.
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss
@MODELS.register_module()
class SeesawLoss(nn.Module):
    """
    Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    arXiv: https://arxiv.org/abs/2008.10032

    The classifier is expected to emit ``num_classes + 2`` channels: the
    first ``num_classes`` are class logits, the last two are objectness
    logits (index 0 = foreground, index 1 = background).

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
             or softmax. Only False is supported.
        p (float, optional): The ``p`` in the mitigation factor.
             Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor.
             Defaults to 2.0.
        num_classes (int, optional): The number of classes.
             Defaults to 1203 for LVIS v1 dataset.
        eps (float, optional): The minimal value of divisor to smooth
             the computation of compensation factor. Defaults to 1e-2.
        reduction (str, optional): The method that reduces the loss to a
             scalar. Options are "none", "mean" and "sum".
             Defaults to 'mean'.
        loss_weight (float, optional): The weight of the loss.
             Defaults to 1.0.
        return_dict (bool, optional): Whether return the losses as a dict.
             Defaults to True.
    """

    def __init__(self,
                 use_sigmoid: bool = False,
                 p: float = 0.8,
                 q: float = 2.0,
                 num_classes: int = 1203,
                 eps: float = 1e-2,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_dict: bool = True) -> None:
        super().__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        # loss function applied to the class logits of positive samples
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category; the extra last slot
        # (index num_classes) counts background samples
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
        """Split cls_score into class logits and objectness logits.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tuple[Tensor, Tensor]: The score for classes and objectness,
                respectively.
        """
        # split cls_score to cls_score_classes and cls_score_objectness
        assert cls_score.size(-1) == self.num_classes + 2
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness

    def get_cls_channels(self, num_classes: int) -> int:
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes.

        Returns:
            int: The custom classification channels.
        """
        assert num_classes == self.num_classes
        return num_classes + 2

    def get_activation(self, cls_score: Tensor) -> Tensor:
        """Get custom activation of cls_score.

        Class probabilities are a softmax over the class logits scaled by the
        foreground objectness probability; the background objectness
        probability is appended as the last column.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tensor: The custom activation of cls_score with shape
                (N, C + 1).
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        score_classes = score_classes * score_pos
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores

    def get_accuracy(self, cls_score: Tensor,
                     labels: Tensor) -> Dict[str, Tensor]:
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction.

        Returns:
            Dict [str, Tensor]: The accuracy for objectness and classes,
                respectively.
        """
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc

    def forward(
        self,
        cls_score: Tensor,
        labels: Tensor,
        label_weights: Optional[Tensor] = None,
        avg_factor: Optional[int] = None,
        reduction_override: Optional[str] = None
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Forward function.

        Side effect: adds the per-class label counts of this batch to the
        ``cum_samples`` buffer.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction; the value
                ``num_classes`` denotes background.
            label_weights (Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override ``self.reduction``. Options are "none", "mean" and
                "sum". Defaults to None.

        Returns:
            Tensor | Dict [str, Tensor]:
                 if return_dict == False: The calculated loss |
                 if return_dict == True: The dict of calculated losses
                 for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category (background included).
        # A single bincount replaces a per-label Python loop that needed a
        # device->host sync (.item()) for every distinct label.
        batch_counts = torch.bincount(
            labels.reshape(-1), minlength=self.num_classes + 1)
        self.cum_samples += batch_counts.to(self.cum_samples)

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.any():
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
            # no positives: a zero loss that still depends on cls_score
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls
|