123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import math
- from typing import Optional
- import torch
- import torch.nn as nn
- from mmengine.model import ExponentialMovingAverage
- from torch import Tensor
- from mmdet.registry import MODELS
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with an exponential momentum
    strategy, as used in YOLOX.

    Instead of a fixed momentum, the effective momentum starts large and
    decays exponentially towards ``momentum`` as the number of update steps
    grows, so the EMA model is updated smoothly.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The base momentum used for updating the EMA
            parameters. The EMA parameters are updated with the formula
            ``averaged_param = (1 - m) * averaged_param + m * source_param``,
            where ``m`` is the effective momentum described under ``gamma``.
            Defaults to 0.0002.
        gamma (int): Controls how fast the effective momentum is annealed
            from a larger value towards ``momentum``. The effective momentum
            is ``m = (1 - momentum) * exp(-(1 + steps) / gamma) + momentum``.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Validate before the (potentially expensive) parent initialisation
        # so an invalid ``gamma`` fails fast. An explicit raise instead of a
        # bare ``assert`` keeps the check active under ``python -O`` while
        # preserving the AssertionError type callers may already expect.
        # ``not gamma > 0`` (rather than ``gamma <= 0``) also rejects NaN,
        # matching the original assert.
        if not gamma > 0:
            raise AssertionError(
                f'gamma must be greater than 0, but got {gamma}')
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Update ``averaged_param`` in place using the exponentially
        annealed momentum.

        Args:
            averaged_param (Tensor): The averaged parameters; modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Effective momentum: decays from a larger value towards
        # ``self.momentum`` as ``steps`` grows (``self.gamma`` > 0).
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        # In place: averaged = (1 - momentum) * averaged + momentum * source.
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
|