123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmdet.registry import MODELS
- from .utils import weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Compute the softmax cross-entropy loss.

    The per-element loss comes from ``F.cross_entropy`` with
    ``reduction='none'``; sample weighting and the final reduction are
    delegated to ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before being applied.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``weight`` to ``F.cross_entropy``.
        ignore_index (int | None): The label index to be ignored.
            ``None`` falls back to -100. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Only takes effect when ``reduction`` is
            ``'mean'`` and no ``avg_factor`` is given. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Fall back to the same default that F.cross_entropy uses.
    if ignore_index is None:
        ignore_index = -100

    # Unreduced, element-wise loss.
    per_element = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # PyTorch's own cross_entropy averages over non-ignored elements only;
    # reproduce that by deriving the averaging factor from the labels.
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
    if avg_non_ignore and reduction == 'mean' and avg_factor is None:
        num_ignored = (label == ignore_index).sum().item()
        avg_factor = label.numel() - num_ignored

    # Apply the optional sample weights and reduce.
    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        per_element,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
- def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
- """Expand onehot labels to match the size of prediction."""
- bin_labels = labels.new_full((labels.size(0), label_channels), 0)
- valid_mask = (labels >= 0) & (labels != ignore_index)
- inds = torch.nonzero(
- valid_mask & (labels < label_channels), as_tuple=False)
- if inds.numel() > 0:
- bin_labels[inds, labels[inds]] = 1
- valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
- label_channels).float()
- if label_weights is None:
- bin_label_weights = valid_mask
- else:
- bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
- bin_label_weights *= valid_mask
- return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Compute the sigmoid (binary) cross-entropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When ``pred`` and ``label`` differ in number of dimensions, the
            label is expanded to one-hot format with ``pred.size(-1)``
            channels; otherwise the label is used as given.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight. It is
            multiplied by the mask of valid (non-ignored) elements.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): The label index to be ignored.
            ``None`` falls back to -100. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Only takes effect when ``reduction`` is
            ``'mean'`` and no ``avg_factor`` is given. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Fall back to the same default that F.cross_entropy uses.
    if ignore_index is None:
        ignore_index = -100

    if pred.dim() == label.dim():
        # Labels already match the prediction; just mask out negative
        # labels and the ignore index.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        # Out-of-place multiply on purpose: an in-place one raises a
        # broadcast shape error when weight and valid_mask differ in
        # shape, e.g. (B,N,1) and (B,N,C).
        weight = valid_mask if weight is None else weight * valid_mask
    else:
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)

    # Average over the non-ignored elements only, when requested.
    if avg_non_ignore and reduction == 'mean' and avg_factor is None:
        avg_factor = valid_mask.sum().item()

    # Weighted element-wise loss followed by the reduction.
    elem_weight = weight.float()
    elem_loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        elem_loss, elem_weight, reduction=reduction, avg_factor=avg_factor)
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Compute the binary cross-entropy loss for mask predictions.

    For every sample, the mask channel given by ``label`` is picked from
    ``pred`` and compared against ``target`` with
    ``F.binary_cross_entropy_with_logits``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): The class label of the object each mask
            belongs to. It selects which of the C predicted masks is used
            for that object when the mask prediction is not class-agnostic.
        reduction (str, optional): Reserved; must be ``'mean'``.
        avg_factor (int, optional): Reserved; must be None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder, to be consistent with other
            losses. Must be None. Default: None.
        **kwargs: Accepted for signature compatibility and not used.

    Returns:
        torch.Tensor: The calculated loss, with shape (1, ).

    Raises:
        AssertionError: If ``ignore_index`` is not None, ``reduction`` is
            not ``'mean'`` or ``avg_factor`` is not None.

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        ...                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    # Pair every sample index with its class label to pick one mask each.
    roi_ids = torch.arange(
        0, pred.size(0), dtype=torch.long, device=pred.device)
    selected_logits = pred[roi_ids, label].squeeze(1)

    mean_loss = F.binary_cross_entropy_with_logits(
        selected_logits, target, weight=class_weight, reduction='mean')
    # Add a leading dimension so the result has shape (1, ).
    return mean_loss[None]
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Configurable cross-entropy loss module.

    Depending on the flags given at construction, the loss is computed by
    ``binary_cross_entropy`` (``use_sigmoid``), ``mask_cross_entropy``
    (``use_mask``) or ``cross_entropy`` (neither flag set).
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses
                sigmoid instead of softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy
                loss. Cannot be combined with ``use_sigmoid``.
                Defaults to False.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss.
                Defaults to 1.0.
            avg_non_ignore (bool): Whether the loss is only averaged over
                non-ignored targets. Default: False.
        """
        super().__init__()
        # The sigmoid and mask variants are mutually exclusive.
        assert use_sigmoid is False or use_mask is False

        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore

        # Warn that, with an ignore index set, a 'mean' reduction still
        # averages over every label unless avg_non_ignore is enabled.
        averages_over_ignored = (
            ignore_index is not None and not self.avg_non_ignore
            and self.reduction == 'mean')
        if averages_over_ignored:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Pick the functional loss once, at construction time.
        criterion = cross_entropy
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        elif self.use_mask:
            criterion = mask_cross_entropy
        self.cls_criterion = criterion

    def extra_repr(self):
        """Return the extra text shown in the module's repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        ``cls_score``, ``label`` and ``weight`` are handed to the selected
        criterion positionally, in that order. With ``use_mask=True`` they
        therefore arrive as the ``pred``, ``target`` and ``label``
        parameters of ``mask_cross_entropy``.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce
                the loss. Options are "none", "mean" and "sum". When not
                given, the reduction set at construction is used.
            ignore_index (int | None): The label index to be ignored.
                If not None, it overrides the value set at construction.
                Default: None.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if ignore_index is None:
            ignore_index = self.ignore_index

        # Build the class-weight tensor with the dtype and device of the
        # prediction.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return self.loss_weight * raw_loss
|