123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmpose.models.losses.regression_loss import SoftWeightSmoothL1Loss
class TestSoftWeightSmoothL1Loss(TestCase):
    """Tests for ``SoftWeightSmoothL1Loss``."""

    def _assert_loss_close(self, actual, expected):
        """Assert that the tensor ``actual`` is ``allclose`` to ``expected``.

        Args:
            actual (torch.Tensor): Loss value produced by the module.
            expected (float): Reference scalar value.
        """
        self.assertTrue(torch.allclose(actual, torch.tensor(expected)))

    def test_loss(self):
        """Check loss values across weighting modes and reductions."""
        shape = (1, 3, 2)

        # --- Unweighted loss (use_target_weight=False), beta=0.5 ---
        criterion = SoftWeightSmoothL1Loss(use_target_weight=False, beta=0.5)
        # Identical prediction and label -> zero loss.
        self._assert_loss_close(
            criterion(torch.zeros(shape), torch.zeros(shape)), 0.)
        # |pred - label| == 1 >= beta everywhere; under the standard
        # smooth-L1 form each element contributes 1 - beta / 2 = 0.75.
        self._assert_loss_close(
            criterion(torch.ones(shape), torch.zeros(shape)), .75)

        # --- Weighted loss that also supervises zero-weight entries ---
        criterion = SoftWeightSmoothL1Loss(
            use_target_weight=True, supervise_empty=True)
        pred = torch.ones(shape)
        label = torch.zeros(shape)
        weight = torch.arange(6).reshape(*shape).float()  # weights 0..5
        self._assert_loss_close(criterion(pred, label, weight), 1.25)

        # --- Weighted loss that ignores zero-weight entries ---
        # Expected value presumably normalises by the 5 non-zero weights
        # rather than all 6 elements (7.5 / 5) -- confirm against the module.
        criterion = SoftWeightSmoothL1Loss(
            use_target_weight=True, supervise_empty=False)
        self._assert_loss_close(criterion(pred, label, weight), 1.5)

        # An unknown reduction mode must be rejected.
        with self.assertRaises(ValueError):
            _ = criterion.smooth_l1_loss(pred, label, reduction='fake')
        # Explicit 'sum' reduction over the six unit differences.
        total = criterion.smooth_l1_loss(pred, label, reduction='sum')
        self._assert_loss_close(total, 3.0)
|