kd_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def knowledge_distillation_kl_div_loss(pred: Tensor,
                                       soft_label: Tensor,
                                       T: int,
                                       detach_target: bool = True) -> Tensor:
    r"""Loss function for knowledge distilling using KL divergence.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1).
        T (int): Temperature for distillation.
        detach_target (bool): Remove soft_label from automatic
            differentiation. Defaults to True.

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()
    target = F.softmax(soft_label / T, dim=1)
    if detach_target:
        target = target.detach()

    kd_loss = F.kl_div(
        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
            T * T)

    return kd_loss
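
# Clarifying note on the expression above: with t = softmax(soft_label / T)
# and log_p = log_softmax(pred / T), F.kl_div(log_p, t, reduction='none')
# returns the elementwise terms t * (log t - log_p); `.mean(1)` averages
# them over the n + 1 classes, and the final `T * T` factor compensates for
# the 1/T**2 shrinkage of gradients that temperature softening introduces
# (Hinton et al., 2015).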


@MODELS.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 T: int = 10) -> None:
        super().__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred: Tensor,
                soft_label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)
        return loss_kd
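
A minimal usage sketch (assuming mmdet is installed and, as in upstream mmdet, re-exports this class from mmdet.models.losses; the shapes and temperature below are illustrative values, not taken from the file above):

import torch
from mmdet.models.losses import KnowledgeDistillationKLDivLoss

# Instantiate the loss directly; inside a detector config the same loss is
# usually expressed as dict(type='KnowledgeDistillationKLDivLoss',
# loss_weight=1.0, T=10) and built through the MODELS registry.
kd_loss = KnowledgeDistillationKLDivLoss(reduction='mean', loss_weight=1.0, T=10)

student_logits = torch.randn(8, 81)  # (N, n + 1) predicted logits
teacher_logits = torch.randn(8, 81)  # soft targets with the same shape

loss = kd_loss(student_logits, teacher_logits)  # scalar tensor under 'mean'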