# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class MeanTeacherHook(Hook):
    """Mean Teacher Hook.

    Mean Teacher is an efficient semi-supervised learning method introduced
    in `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.

    This method requires two models with exactly the same structure, used
    as the student model and the teacher model, respectively. The student
    model updates its parameters through gradient descent, while the
    teacher model updates its parameters as an exponential moving average
    (EMA) of the student model's parameters. Compared with the student
    model, the teacher model is smoother and accumulates more knowledge.

    Args:
        momentum (float): The momentum used for updating the teacher's
            parameters. The teacher's parameters are updated with the
            formula: `teacher = (1 - momentum) * teacher + momentum *
            student`. Defaults to 0.001.
        interval (int): Update the teacher's parameters every ``interval``
            iterations. Defaults to 1.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batchnorm running stats (``running_mean``, ``running_var``),
            so that they are excluded from the EMA update. Defaults to True.
    """
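    # A minimal usage sketch: in an MMDetection config this hook is
    # typically enabled via ``custom_hooks``; the surrounding
    # semi-supervised config is assumed to build a model that exposes
    # ``student`` and ``teacher`` submodules:
    #
    #     custom_hooks = [
    #         dict(type='MeanTeacherHook', momentum=0.001, interval=1)
    #     ]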

    def __init__(self,
                 momentum: float = 0.001,
                 interval: int = 1,
                 skip_buffers: bool = True) -> None:
        assert 0 < momentum < 1
        self.momentum = momentum
        self.interval = interval
        self.skip_buffers = skip_buffers

    def before_train(self, runner: Runner) -> None:
        """Check that the teacher model and the student model exist."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        assert hasattr(model, 'teacher')
        assert hasattr(model, 'student')
        # Only at the initial iteration: with a momentum of 1 the EMA update
        # reduces to copying the student's weights into the teacher.
        if runner.iter == 0:
            self.momentum_update(model, 1)

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update the teacher's parameters every ``self.interval``
        iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self.momentum_update(model, self.momentum)
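
    # An EMA with momentum ``m`` has an effective averaging window of
    # roughly ``1 / m`` updates, so the default ``momentum=0.001`` makes
    # the teacher reflect approximately the last 1000 student updates;
    # a larger ``interval`` spreads those updates over more iterations.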

    def momentum_update(self, model: nn.Module, momentum: float) -> None:
        """Compute the moving average of the parameters using exponential
        moving average."""
        if self.skip_buffers:
            # Update only the learnable parameters; the teacher keeps its
            # own buffer values.
            for (src_name, src_parm), (dst_name, dst_parm) in zip(
                    model.student.named_parameters(),
                    model.teacher.named_parameters()):
                dst_parm.data.mul_(1 - momentum).add_(
                    src_parm.data, alpha=momentum)
        else:
            # Update parameters and buffers alike through the state dict.
            for (src_parm,
                 dst_parm) in zip(model.student.state_dict().values(),
                                  model.teacher.state_dict().values()):
                # Exclude integer buffers such as ``num_batches_tracked``.
                if dst_parm.dtype.is_floating_point:
                    dst_parm.data.mul_(1 - momentum).add_(
                        src_parm.data, alpha=momentum)
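

# A minimal sanity-check sketch of the EMA update; the ``_DummyWrapper``
# class below is hypothetical and only mimics the ``student``/``teacher``
# attributes the hook expects:
if __name__ == '__main__':
    import torch

    class _DummyWrapper(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.student = nn.Linear(2, 2)
            self.teacher = nn.Linear(2, 2)

    hook = MeanTeacherHook(momentum=0.5)
    demo = _DummyWrapper()
    # With momentum=1 the EMA update reduces to a full copy.
    hook.momentum_update(demo, momentum=1)
    assert torch.allclose(demo.teacher.weight, demo.student.weight)
    # With momentum=0.5 the teacher moves halfway towards the student.
    with torch.no_grad():
        demo.student.weight.add_(1.0)
    hook.momentum_update(demo, momentum=0.5)
    assert torch.allclose(demo.teacher.weight, demo.student.weight - 0.5)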