checkloss_hook.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. from mmengine.hooks import Hook
  5. from mmengine.runner import Runner
  6. from mmdet.registry import HOOKS
  7. @HOOKS.register_module()
  8. class CheckInvalidLossHook(Hook):
  9. """Check invalid loss hook.
  10. This hook will regularly check whether the loss is valid
  11. during training.
  12. Args:
  13. interval (int): Checking interval (every k iterations).
  14. Default: 50.
  15. """
  16. def __init__(self, interval: int = 50) -> None:
  17. self.interval = interval
  18. def after_train_iter(self,
  19. runner: Runner,
  20. batch_idx: int,
  21. data_batch: Optional[dict] = None,
  22. outputs: Optional[dict] = None) -> None:
  23. """Regularly check whether the loss is valid every n iterations.
  24. Args:
  25. runner (:obj:`Runner`): The runner of the training process.
  26. batch_idx (int): The index of the current batch in the train loop.
  27. data_batch (dict, Optional): Data from dataloader.
  28. Defaults to None.
  29. outputs (dict, Optional): Outputs from model. Defaults to None.
  30. """
  31. if self.every_n_train_iters(runner, self.interval):
  32. assert torch.isfinite(outputs['loss']), \
  33. runner.logger.info('loss become infinite or NaN!')