# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def knowledge_distillation_kl_div_loss(pred: Tensor,
                                       soft_label: Tensor,
                                       T: int,
                                       detach_target: bool = True) -> Tensor:
    r"""Loss function for knowledge distilling using KL divergence.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1).
        T (int): Temperature for distillation.
        detach_target (bool): Remove soft_label from automatic
            differentiation. Defaults to True.

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()
    # Teacher distribution: temperature-softened softmax of the soft labels.
    target = F.softmax(soft_label / T, dim=1)
    if detach_target:
        target = target.detach()

    # Class-averaged pointwise KL divergence, rescaled by T^2 so the gradient
    # magnitude stays comparable across temperatures.
    kd_loss = F.kl_div(
        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
            T * T)

    return kd_loss
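

# A quick sanity-check sketch (not part of the original module; assumes
# `import torch` and example shapes such as 80 classes + 1 background). With
# reduction='none' the decorated function returns one value per sample:
# T^2 times the class-averaged pointwise KL divergence between the
# temperature-softened teacher and student distributions, e.g.:
#
#   pred = torch.randn(4, 81)         # student logits, (N, n + 1)
#   soft_label = torch.randn(4, 81)   # teacher logits, same shape
#   loss = knowledge_distillation_kl_div_loss(
#       pred, soft_label, T=10, reduction='none')
#   assert loss.shape == (4,)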


@MODELS.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        T (int): Temperature for distillation. Defaults to 10.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 T: int = 10) -> None:
        super().__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred: Tensor,
                soft_label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)

        return loss_kd
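

# A minimal usage sketch (not part of the original module; assumes mmdet and
# torch are installed, and example logits with 80 classes + 1 background).
# The loss can be instantiated directly or built from a config dict through
# the registry, as is usual in mmdet configs:
#
#   loss_kd = KnowledgeDistillationKLDivLoss(
#       reduction='mean', loss_weight=0.5, T=10)
#   student_logits = torch.randn(8, 81)   # (N, n + 1) logits from the student
#   teacher_logits = torch.randn(8, 81)   # matching logits from the teacher
#   loss = loss_kd(student_logits, teacher_logits)
#
#   # or via the registry / config system:
#   loss_kd = MODELS.build(
#       dict(type='KnowledgeDistillationKLDivLoss', loss_weight=0.5, T=10))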