1234567891011121314151617181920212223242526272829303132333435363738 |
- import unittest
- import torch
- from mmdet.models.losses import GaussianFocalLoss
- class TestGaussianFocalLoss(unittest.TestCase):
- def test_forward(self):
- pred = torch.rand((10, 4))
- target = torch.rand((10, 4))
- gaussian_focal_loss = GaussianFocalLoss()
- loss1 = gaussian_focal_loss(pred, target)
- self.assertIsInstance(loss1, torch.Tensor)
- loss2 = gaussian_focal_loss(pred, target, avg_factor=0.5)
- self.assertIsInstance(loss2, torch.Tensor)
- # test reduction
- gaussian_focal_loss = GaussianFocalLoss(reduction='none')
- loss = gaussian_focal_loss(pred, target)
- self.assertTrue(loss.shape == (10, 4))
- # test reduction_override
- loss = gaussian_focal_loss(pred, target, reduction_override='mean')
- self.assertTrue(loss.ndim == 0)
- # Only supports None, 'none', 'mean', 'sum'
- with self.assertRaises(AssertionError):
- gaussian_focal_loss(pred, target, reduction_override='max')
- # test pos_inds
- pos_inds = (torch.rand(5) * 8).long()
- pos_labels = (torch.rand(5) * 2).long()
- gaussian_focal_loss = GaussianFocalLoss()
- loss = gaussian_focal_loss(pred, target, pos_inds, pos_labels)
- self.assertIsInstance(loss, torch.Tensor)
|