# test_classification_losses.py
# Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.losses.classification_loss import InfoNCELoss
  5. class TestInfoNCELoss(TestCase):
  6. def test_loss(self):
  7. # test loss w/o target_weight
  8. loss = InfoNCELoss(temperature=0.05)
  9. fake_pred = torch.arange(5 * 2).reshape(5, 2).float()
  10. self.assertTrue(
  11. torch.allclose(loss(fake_pred), torch.tensor(5.4026), atol=1e-4))
  12. # check if the value of temperature is positive
  13. with self.assertRaises(AssertionError):
  14. loss = InfoNCELoss(temperature=0.)