123456789101112131415161718192021222324252627282930313233343536373839404142 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional
- import torch
- from mmengine.hooks import Hook
- from mmengine.runner import Runner
- from mmdet.registry import HOOKS
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook regularly checks whether the training loss is finite
    (neither infinite nor NaN) and aborts training when it is not.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        The check only runs on iterations for which
        ``self.every_n_train_iters(runner, self.interval)`` is true.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. Its ``'loss'``
                entry is the value being checked. Defaults to None.

        Raises:
            AssertionError: If ``outputs['loss']`` contains an infinite or
                NaN value.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return

        # ``.all()`` reduces the element-wise result to a single boolean, so
        # the check is also well defined for a non-scalar loss tensor.
        if not torch.isfinite(outputs['loss']).all():
            message = 'loss become infinite or NaN!'
            runner.logger.info(message)
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O`` and the exception carries the
            # message (``logger.info`` returns None, so the original
            # assertion had an empty message).
            raise AssertionError(message)
|