test_loss.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.utils import digit_version
  6. from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
  7. DistributionFocalLoss, FocalLoss,
  8. GaussianFocalLoss,
  9. KnowledgeDistillationKLDivLoss, L1Loss,
  10. MSELoss, QualityFocalLoss, SeesawLoss,
  11. SmoothL1Loss, VarifocalLoss)
  12. from mmdet.models.losses.ghm_loss import GHMC, GHMR
  13. from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss,
  14. EIoULoss, GIoULoss, IoULoss)
  15. @pytest.mark.parametrize(
  16. 'loss_class',
  17. [IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss])
  18. def test_iou_type_loss_zeros_weight(loss_class):
  19. pred = torch.rand((10, 4))
  20. target = torch.rand((10, 4))
  21. weight = torch.zeros(10)
  22. loss = loss_class()(pred, target, weight)
  23. assert loss == 0.
  24. @pytest.mark.parametrize('loss_class', [
  25. BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
  26. EIoULoss, FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss,
  27. GaussianFocalLoss, GIoULoss, QualityFocalLoss, IoULoss, L1Loss,
  28. VarifocalLoss, GHMR, GHMC, SmoothL1Loss, KnowledgeDistillationKLDivLoss,
  29. DiceLoss
  30. ])
  31. def test_loss_with_reduction_override(loss_class):
  32. pred = torch.rand((10, 4))
  33. target = torch.rand((10, 4)),
  34. weight = None
  35. with pytest.raises(AssertionError):
  36. # only reduction_override from [None, 'none', 'mean', 'sum']
  37. # is not allowed
  38. reduction_override = True
  39. loss_class()(
  40. pred, target, weight, reduction_override=reduction_override)
  41. @pytest.mark.parametrize('loss_class', [QualityFocalLoss])
  42. @pytest.mark.parametrize('activated', [False, True])
  43. def test_QualityFocalLoss_Loss(loss_class, activated):
  44. input_shape = (4, 5)
  45. pred = torch.rand(input_shape)
  46. label = torch.Tensor([0, 1, 2, 0]).long()
  47. quality_label = torch.rand(input_shape[0])
  48. original_loss = loss_class(activated=activated)(pred,
  49. (label, quality_label))
  50. assert isinstance(original_loss, torch.Tensor)
  51. target = torch.nn.functional.one_hot(label, 5)
  52. target = target * quality_label.reshape(input_shape[0], 1)
  53. new_loss = loss_class(activated=activated)(pred, target)
  54. assert isinstance(new_loss, torch.Tensor)
  55. assert new_loss == original_loss
  56. @pytest.mark.parametrize('loss_class', [
  57. IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, EIoULoss, MSELoss,
  58. L1Loss, SmoothL1Loss, BalancedL1Loss
  59. ])
  60. @pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
  61. def test_regression_losses(loss_class, input_shape):
  62. pred = torch.rand(input_shape)
  63. target = torch.rand(input_shape)
  64. weight = torch.rand(input_shape)
  65. # Test loss forward
  66. loss = loss_class()(pred, target)
  67. assert isinstance(loss, torch.Tensor)
  68. # Test loss forward with weight
  69. loss = loss_class()(pred, target, weight)
  70. assert isinstance(loss, torch.Tensor)
  71. # Test loss forward with reduction_override
  72. loss = loss_class()(pred, target, reduction_override='mean')
  73. assert isinstance(loss, torch.Tensor)
  74. # Test loss forward with avg_factor
  75. loss = loss_class()(pred, target, avg_factor=10)
  76. assert isinstance(loss, torch.Tensor)
  77. with pytest.raises(ValueError):
  78. # loss can evaluate with avg_factor only if
  79. # reduction is None, 'none' or 'mean'.
  80. reduction_override = 'sum'
  81. loss_class()(
  82. pred, target, avg_factor=10, reduction_override=reduction_override)
  83. # Test loss forward with avg_factor and reduction
  84. for reduction_override in [None, 'none', 'mean']:
  85. loss_class()(
  86. pred, target, avg_factor=10, reduction_override=reduction_override)
  87. assert isinstance(loss, torch.Tensor)
  88. @pytest.mark.parametrize('loss_class', [CrossEntropyLoss])
  89. @pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
  90. def test_classification_losses(loss_class, input_shape):
  91. if input_shape[0] == 0 and digit_version(
  92. torch.__version__) < digit_version('1.5.0'):
  93. pytest.skip(
  94. f'CELoss in PyTorch {torch.__version__} does not support empty'
  95. f'tensor.')
  96. pred = torch.rand(input_shape)
  97. target = torch.randint(0, 5, (input_shape[0], ))
  98. # Test loss forward
  99. loss = loss_class()(pred, target)
  100. assert isinstance(loss, torch.Tensor)
  101. # Test loss forward with reduction_override
  102. loss = loss_class()(pred, target, reduction_override='mean')
  103. assert isinstance(loss, torch.Tensor)
  104. # Test loss forward with avg_factor
  105. loss = loss_class()(pred, target, avg_factor=10)
  106. assert isinstance(loss, torch.Tensor)
  107. with pytest.raises(ValueError):
  108. # loss can evaluate with avg_factor only if
  109. # reduction is None, 'none' or 'mean'.
  110. reduction_override = 'sum'
  111. loss_class()(
  112. pred, target, avg_factor=10, reduction_override=reduction_override)
  113. # Test loss forward with avg_factor and reduction
  114. for reduction_override in [None, 'none', 'mean']:
  115. loss_class()(
  116. pred, target, avg_factor=10, reduction_override=reduction_override)
  117. assert isinstance(loss, torch.Tensor)
  118. @pytest.mark.parametrize('loss_class', [FocalLoss])
  119. @pytest.mark.parametrize('input_shape', [(10, 5), (3, 5, 40, 40)])
  120. def test_FocalLoss_loss(loss_class, input_shape):
  121. pred = torch.rand(input_shape)
  122. target = torch.randint(0, 5, (input_shape[0], ))
  123. if len(input_shape) == 4:
  124. B, N, W, H = input_shape
  125. target = F.one_hot(torch.randint(0, 5, (B * W * H, )),
  126. 5).reshape(B, W, H, N).permute(0, 3, 1, 2)
  127. # Test loss forward
  128. loss = loss_class()(pred, target)
  129. assert isinstance(loss, torch.Tensor)
  130. # Test loss forward with reduction_override
  131. loss = loss_class()(pred, target, reduction_override='mean')
  132. assert isinstance(loss, torch.Tensor)
  133. # Test loss forward with avg_factor
  134. loss = loss_class()(pred, target, avg_factor=10)
  135. assert isinstance(loss, torch.Tensor)
  136. with pytest.raises(ValueError):
  137. # loss can evaluate with avg_factor only if
  138. # reduction is None, 'none' or 'mean'.
  139. reduction_override = 'sum'
  140. loss_class()(
  141. pred, target, avg_factor=10, reduction_override=reduction_override)
  142. # Test loss forward with avg_factor and reduction
  143. for reduction_override in [None, 'none', 'mean']:
  144. loss_class()(
  145. pred, target, avg_factor=10, reduction_override=reduction_override)
  146. assert isinstance(loss, torch.Tensor)
  147. @pytest.mark.parametrize('loss_class', [GHMR])
  148. @pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
  149. def test_GHMR_loss(loss_class, input_shape):
  150. pred = torch.rand(input_shape)
  151. target = torch.rand(input_shape)
  152. weight = torch.rand(input_shape)
  153. # Test loss forward
  154. loss = loss_class()(pred, target, weight)
  155. assert isinstance(loss, torch.Tensor)
  156. @pytest.mark.parametrize('use_sigmoid', [True, False])
  157. @pytest.mark.parametrize('reduction', ['sum', 'mean', None])
  158. @pytest.mark.parametrize('avg_non_ignore', [True, False])
  159. def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
  160. # Test cross_entropy loss
  161. loss_class = CrossEntropyLoss(
  162. use_sigmoid=use_sigmoid,
  163. use_mask=False,
  164. ignore_index=255,
  165. avg_non_ignore=avg_non_ignore)
  166. pred = torch.rand((10, 5))
  167. target = torch.randint(0, 5, (10, ))
  168. ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
  169. target[ignored_indices] = 255
  170. # Test loss forward with default ignore
  171. loss_with_ignore = loss_class(pred, target, reduction_override=reduction)
  172. assert isinstance(loss_with_ignore, torch.Tensor)
  173. # Test loss forward with forward ignore
  174. target[ignored_indices] = 255
  175. loss_with_forward_ignore = loss_class(
  176. pred, target, ignore_index=255, reduction_override=reduction)
  177. assert isinstance(loss_with_forward_ignore, torch.Tensor)
  178. # Verify correctness
  179. if avg_non_ignore:
  180. # manually remove the ignored elements
  181. not_ignored_indices = (target != 255)
  182. pred = pred[not_ignored_indices]
  183. target = target[not_ignored_indices]
  184. loss = loss_class(pred, target, reduction_override=reduction)
  185. assert torch.allclose(loss, loss_with_ignore)
  186. assert torch.allclose(loss, loss_with_forward_ignore)
  187. # test ignore all target
  188. pred = torch.rand((10, 5))
  189. target = torch.ones((10, ), dtype=torch.long) * 255
  190. loss = loss_class(pred, target, reduction_override=reduction)
  191. assert loss == 0
  192. @pytest.mark.parametrize('naive_dice', [True, False])
  193. def test_dice_loss(naive_dice):
  194. loss_class = DiceLoss
  195. pred = torch.rand((10, 4, 4))
  196. target = torch.rand((10, 4, 4))
  197. weight = torch.rand((10))
  198. # Test loss forward
  199. loss = loss_class(naive_dice=naive_dice)(pred, target)
  200. assert isinstance(loss, torch.Tensor)
  201. # Test loss forward with weight
  202. loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
  203. assert isinstance(loss, torch.Tensor)
  204. # Test loss forward with reduction_override
  205. loss = loss_class(naive_dice=naive_dice)(
  206. pred, target, reduction_override='mean')
  207. assert isinstance(loss, torch.Tensor)
  208. # Test loss forward with avg_factor
  209. loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
  210. assert isinstance(loss, torch.Tensor)
  211. with pytest.raises(ValueError):
  212. # loss can evaluate with avg_factor only if
  213. # reduction is None, 'none' or 'mean'.
  214. reduction_override = 'sum'
  215. loss_class(naive_dice=naive_dice)(
  216. pred, target, avg_factor=10, reduction_override=reduction_override)
  217. # Test loss forward with avg_factor and reduction
  218. for reduction_override in [None, 'none', 'mean']:
  219. loss_class(naive_dice=naive_dice)(
  220. pred, target, avg_factor=10, reduction_override=reduction_override)
  221. assert isinstance(loss, torch.Tensor)
  222. # Test loss forward with has_acted=False and use_sigmoid=False
  223. with pytest.raises(NotImplementedError):
  224. loss_class(
  225. use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
  226. target)
  227. # Test loss forward with weight.ndim != loss.ndim
  228. with pytest.raises(AssertionError):
  229. weight = torch.rand((2, 8))
  230. loss_class(naive_dice=naive_dice)(pred, target, weight)
  231. # Test loss forward with len(weight) != len(pred)
  232. with pytest.raises(AssertionError):
  233. weight = torch.rand((8))
  234. loss_class(naive_dice=naive_dice)(pred, target, weight)