12345678910111213141516171819202122232425262728293031323334353637 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- from unittest.mock import Mock
- import torch
- from mmdet.engine.hooks import CheckInvalidLossHook
class TestCheckInvalidLossHook(TestCase):
    """Unit tests for ``CheckInvalidLossHook``."""

    def test_after_train_iter(self):
        """Exercise ``after_train_iter`` off and on the check iteration.

        The hook is constructed with ``50``.

        * At ``runner.iter == 10`` the test feeds a finite, a NaN and an
          infinite loss and expects none of them to raise.
        * At ``runner.iter == 49`` (i.e. ``50 - 1``) the test expects a
          finite loss to pass and NaN / infinite losses to raise
          ``AssertionError``.
        """
        interval = 50
        hook = CheckInvalidLossHook(interval)

        runner = Mock()
        runner.logger = Mock()
        runner.logger.info = Mock()

        def invalid_losses():
            """Yield a freshly built NaN loss, then an infinite loss."""
            yield torch.tensor(float('nan'))
            yield torch.tensor(float('inf'))

        # Iteration inside the first interval: the test expects every
        # loss value, finite or not, to be accepted silently.
        runner.iter = 10
        hook.after_train_iter(
            runner, 10, outputs=dict(loss=torch.LongTensor([2])))
        for bad_loss in invalid_losses():
            hook.after_train_iter(runner, 10, outputs=dict(loss=bad_loss))

        # Last iteration of the first interval: the test expects the
        # finite loss to pass and each non-finite loss to trip an assert.
        last_iter = interval - 1
        runner.iter = last_iter
        hook.after_train_iter(
            runner, last_iter, outputs=dict(loss=torch.LongTensor([2])))
        for bad_loss in invalid_losses():
            with self.assertRaises(AssertionError):
                hook.after_train_iter(
                    runner, last_iter, outputs=dict(loss=bad_loss))
|