# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class MeanTeacherHook(Hook):
- """Mean Teacher Hook.
- Mean Teacher is an efficient semi-supervised learning method in
- `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
- This method requires two models with exactly the same structure,
- as the student model and the teacher model, respectively.
- The student model updates the parameters through gradient descent,
- and the teacher model updates the parameters through
- exponential moving average of the student model.
- Compared with the student model, the teacher model
- is smoother and accumulates more knowledge.
- Args:
- momentum (float): The momentum used for updating teacher's parameter.
- Teacher's parameter are updated with the formula:
- `teacher = (1-momentum) * teacher + momentum * student`.
- Defaults to 0.001.
- interval (int): Update teacher's parameter every interval iteration.
- Defaults to 1.
- skip_buffers (bool): Whether to skip the model buffers, such as
- batchnorm running stats (running_mean, running_var), it does not
- perform the ema operation. Default to True.
- """

    def __init__(self,
                 momentum: float = 0.001,
                 interval: int = 1,
                 skip_buffers: bool = True) -> None:
        assert 0 < momentum < 1
        self.momentum = momentum
        self.interval = interval
        self.skip_buffers = skip_buffers

    def before_train(self, runner: Runner) -> None:
        """Check that the teacher model and the student model exist."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        assert hasattr(model, 'teacher')
        assert hasattr(model, 'student')
        # Only at the initial stage: with momentum=1 the update below copies
        # the student's parameters into the teacher verbatim.
        if runner.iter == 0:
            self.momentum_update(model, 1)

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update the teacher's parameters every `self.interval` iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self.momentum_update(model, self.momentum)

    def momentum_update(self, model: nn.Module, momentum: float) -> None:
        """Update the teacher's parameters as an exponential moving average
        of the student's parameters."""
        if self.skip_buffers:
            # Update parameters only; buffers (e.g. BN running stats) are
            # left untouched.
            for (src_name, src_parm), (dst_name, dst_parm) in zip(
                    model.student.named_parameters(),
                    model.teacher.named_parameters()):
                dst_parm.data.mul_(1 - momentum).add_(
                    src_parm.data, alpha=momentum)
        else:
            for (src_parm,
                 dst_parm) in zip(model.student.state_dict().values(),
                                  model.teacher.state_dict().values()):
                # Skip non-floating-point entries such as the
                # `num_batches_tracked` buffer of BatchNorm layers.
                if dst_parm.dtype.is_floating_point:
                    dst_parm.data.mul_(1 - momentum).add_(
                        src_parm.data, alpha=momentum)
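

# A minimal sketch of the EMA update in isolation (the `Pair` container and
# this usage are illustrative assumptions, not part of the hook's API):
#
#   class Pair(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.student = nn.Linear(2, 2)
#           self.teacher = nn.Linear(2, 2)
#
#   pair = Pair()
#   hook = MeanTeacherHook(momentum=0.001)
#   hook.momentum_update(pair, momentum=1)     # teacher <- student (copy)
#   hook.momentum_update(pair, hook.momentum)  # one EMA step:
#                                              # teacher = 0.999 * teacher
#                                              #         + 0.001 * student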