test_regression_losses.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmpose.models.losses.regression_loss import SoftWeightSmoothL1Loss
  5. class TestSoftWeightSmoothL1Loss(TestCase):
  6. def test_loss(self):
  7. # test loss w/o target_weight
  8. loss = SoftWeightSmoothL1Loss(use_target_weight=False, beta=0.5)
  9. fake_pred = torch.zeros((1, 3, 2))
  10. fake_label = torch.zeros((1, 3, 2))
  11. self.assertTrue(
  12. torch.allclose(loss(fake_pred, fake_label), torch.tensor(0.)))
  13. fake_pred = torch.ones((1, 3, 2))
  14. fake_label = torch.zeros((1, 3, 2))
  15. self.assertTrue(
  16. torch.allclose(loss(fake_pred, fake_label), torch.tensor(.75)))
  17. # test loss w/ target_weight
  18. loss = SoftWeightSmoothL1Loss(
  19. use_target_weight=True, supervise_empty=True)
  20. fake_pred = torch.ones((1, 3, 2))
  21. fake_label = torch.zeros((1, 3, 2))
  22. fake_weight = torch.arange(6).reshape(1, 3, 2).float()
  23. self.assertTrue(
  24. torch.allclose(
  25. loss(fake_pred, fake_label, fake_weight), torch.tensor(1.25)))
  26. # test loss that does not take empty channels into account
  27. loss = SoftWeightSmoothL1Loss(
  28. use_target_weight=True, supervise_empty=False)
  29. self.assertTrue(
  30. torch.allclose(
  31. loss(fake_pred, fake_label, fake_weight), torch.tensor(1.5)))
  32. with self.assertRaises(ValueError):
  33. _ = loss.smooth_l1_loss(fake_pred, fake_label, reduction='fake')
  34. output = loss.smooth_l1_loss(fake_pred, fake_label, reduction='sum')
  35. self.assertTrue(torch.allclose(output, torch.tensor(3.0)))