_fast_stop_training_hook.py 964 B

123456789101112131415161718192021222324252627
  1. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from mmengine.hooks import Hook

from mmdet.registry import HOOKS
  4. @HOOKS.register_module()
  5. class FastStopTrainingHook(Hook):
  6. """Set runner's epoch information to the model."""
  7. def __init__(self, by_epoch, save_ckpt=False, stop_iter_or_epoch=5):
  8. self.by_epoch = by_epoch
  9. self.save_ckpt = save_ckpt
  10. self.stop_iter_or_epoch = stop_iter_or_epoch
  11. def after_train_iter(self, runner, batch_idx: int, data_batch: None,
  12. outputs: None) -> None:
  13. if self.save_ckpt and self.by_epoch:
  14. # If it is epoch-based and want to save weights,
  15. # we must run at least 1 epoch.
  16. return
  17. if runner.iter >= self.stop_iter_or_epoch:
  18. raise RuntimeError('quick exit')
  19. def after_train_epoch(self, runner) -> None:
  20. if runner.epoch >= self.stop_iter_or_epoch - 1:
  21. raise RuntimeError('quick exit')