# test_gaussian_focal_loss.py (1.3 KB)
import unittest

import torch
from mmdet.models.losses import GaussianFocalLoss
  4. class TestGaussianFocalLoss(unittest.TestCase):
  5. def test_forward(self):
  6. pred = torch.rand((10, 4))
  7. target = torch.rand((10, 4))
  8. gaussian_focal_loss = GaussianFocalLoss()
  9. loss1 = gaussian_focal_loss(pred, target)
  10. self.assertIsInstance(loss1, torch.Tensor)
  11. loss2 = gaussian_focal_loss(pred, target, avg_factor=0.5)
  12. self.assertIsInstance(loss2, torch.Tensor)
  13. # test reduction
  14. gaussian_focal_loss = GaussianFocalLoss(reduction='none')
  15. loss = gaussian_focal_loss(pred, target)
  16. self.assertTrue(loss.shape == (10, 4))
  17. # test reduction_override
  18. loss = gaussian_focal_loss(pred, target, reduction_override='mean')
  19. self.assertTrue(loss.ndim == 0)
  20. # Only supports None, 'none', 'mean', 'sum'
  21. with self.assertRaises(AssertionError):
  22. gaussian_focal_loss(pred, target, reduction_override='max')
  23. # test pos_inds
  24. pos_inds = (torch.rand(5) * 8).long()
  25. pos_labels = (torch.rand(5) * 2).long()
  26. gaussian_focal_loss = GaussianFocalLoss()
  27. loss = gaussian_focal_loss(pred, target, pos_inds, pos_labels)
  28. self.assertIsInstance(loss, torch.Tensor)