set_epoch_info_hook.py 480 B

1234567891011121314151617
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmengine.hooks import Hook
  3. from mmengine.model.wrappers import is_model_wrapper
  4. from mmdet.registry import HOOKS
  5. @HOOKS.register_module()
  6. class SetEpochInfoHook(Hook):
  7. """Set runner's epoch information to the model."""
  8. def before_train_epoch(self, runner):
  9. epoch = runner.epoch
  10. model = runner.model
  11. if is_model_wrapper(model):
  12. model = model.module
  13. model.set_epoch(epoch)