ema.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from mmengine.model import ExponentialMovingAverage
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. @MODELS.register_module()
  10. class ExpMomentumEMA(ExponentialMovingAverage):
  11. """Exponential moving average (EMA) with exponential momentum strategy,
  12. which is used in YOLOX.
  13. Args:
  14. model (nn.Module): The model to be averaged.
  15. momentum (float): The momentum used for updating ema parameter.
  16. Ema's parameter are updated with the formula:
  17. `averaged_param = (1-momentum) * averaged_param + momentum *
  18. source_param`. Defaults to 0.0002.
  19. gamma (int): Use a larger momentum early in training and gradually
  20. annealing to a smaller value to update the ema model smoothly. The
  21. momentum is calculated as
  22. `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
  23. Defaults to 2000.
  24. interval (int): Interval between two updates. Defaults to 1.
  25. device (torch.device, optional): If provided, the averaged model will
  26. be stored on the :attr:`device`. Defaults to None.
  27. update_buffers (bool): if True, it will compute running averages for
  28. both the parameters and the buffers of the model. Defaults to
  29. False.
  30. """
  31. def __init__(self,
  32. model: nn.Module,
  33. momentum: float = 0.0002,
  34. gamma: int = 2000,
  35. interval=1,
  36. device: Optional[torch.device] = None,
  37. update_buffers: bool = False) -> None:
  38. super().__init__(
  39. model=model,
  40. momentum=momentum,
  41. interval=interval,
  42. device=device,
  43. update_buffers=update_buffers)
  44. assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
  45. self.gamma = gamma
  46. def avg_func(self, averaged_param: Tensor, source_param: Tensor,
  47. steps: int) -> None:
  48. """Compute the moving average of the parameters using the exponential
  49. momentum strategy.
  50. Args:
  51. averaged_param (Tensor): The averaged parameters.
  52. source_param (Tensor): The source parameters.
  53. steps (int): The number of times the parameters have been
  54. updated.
  55. """
  56. momentum = (1 - self.momentum) * math.exp(
  57. -float(1 + steps) / self.gamma) + self.momentum
  58. averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)