test_checkloss_hook.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

import torch

from mmdet.engine.hooks import CheckInvalidLossHook


  6. class TestCheckInvalidLossHook(TestCase):
  7. def test_after_train_iter(self):
  8. n = 50
  9. hook = CheckInvalidLossHook(n)
  10. runner = Mock()
  11. runner.logger = Mock()
  12. runner.logger.info = Mock()
  13. # Test `after_train_iter` function within the n iteration.
  14. runner.iter = 10
  15. outputs = dict(loss=torch.LongTensor([2]))
  16. hook.after_train_iter(runner, 10, outputs=outputs)
  17. outputs = dict(loss=torch.tensor(float('nan')))
  18. hook.after_train_iter(runner, 10, outputs=outputs)
  19. outputs = dict(loss=torch.tensor(float('inf')))
  20. hook.after_train_iter(runner, 10, outputs=outputs)
  21. # Test `after_train_iter` at the n iteration.
  22. runner.iter = n - 1
  23. outputs = dict(loss=torch.LongTensor([2]))
  24. hook.after_train_iter(runner, n - 1, outputs=outputs)
  25. outputs = dict(loss=torch.tensor(float('nan')))
  26. with self.assertRaises(AssertionError):
  27. hook.after_train_iter(runner, n - 1, outputs=outputs)
  28. outputs = dict(loss=torch.tensor(float('inf')))
  29. with self.assertRaises(AssertionError):
  30. hook.after_train_iter(runner, n - 1, outputs=outputs)