test_quadratic_warmup.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. from mmengine.optim.scheduler import _ParamScheduler
  7. from mmengine.testing import assert_allclose
  8. from mmdet.engine.schedulers import (QuadraticWarmupLR,
  9. QuadraticWarmupMomentum,
  10. QuadraticWarmupParamScheduler)
  11. class ToyModel(torch.nn.Module):
  12. def __init__(self):
  13. super().__init__()
  14. self.conv1 = torch.nn.Conv2d(1, 1, 1)
  15. self.conv2 = torch.nn.Conv2d(1, 1, 1)
  16. def forward(self, x):
  17. return self.conv2(F.relu(self.conv1(x)))
  18. class TestQuadraticWarmupScheduler(TestCase):
  19. def setUp(self):
  20. """Setup the model and optimizer which are used in every test method.
  21. TestCase calls functions in this order: setUp() -> testMethod() ->
  22. tearDown() -> cleanUp()
  23. """
  24. self.model = ToyModel()
  25. self.optimizer = optim.SGD(
  26. self.model.parameters(), lr=0.05, momentum=0.01, weight_decay=5e-4)
  27. def _test_scheduler_value(self,
  28. schedulers,
  29. targets,
  30. epochs=10,
  31. param_name='lr'):
  32. if isinstance(schedulers, _ParamScheduler):
  33. schedulers = [schedulers]
  34. for epoch in range(epochs):
  35. for param_group, target in zip(self.optimizer.param_groups,
  36. targets):
  37. print(param_group[param_name])
  38. assert_allclose(
  39. target[epoch],
  40. param_group[param_name],
  41. msg='{} is wrong in epoch {}: expected {}, got {}'.format(
  42. param_name, epoch, target[epoch],
  43. param_group[param_name]),
  44. atol=1e-5,
  45. rtol=0)
  46. [scheduler.step() for scheduler in schedulers]
  47. def test_quadratic_warmup_scheduler(self):
  48. with self.assertRaises(ValueError):
  49. QuadraticWarmupParamScheduler(self.optimizer, param_name='lr')
  50. epochs = 10
  51. iters = 5
  52. warmup_factor = [pow((i + 1) / float(iters), 2) for i in range(iters)]
  53. single_targets = [x * 0.05 for x in warmup_factor] + [0.05] * (
  54. epochs - iters)
  55. targets = [single_targets, [x * epochs for x in single_targets]]
  56. scheduler = QuadraticWarmupParamScheduler(
  57. self.optimizer, param_name='lr', end=iters)
  58. self._test_scheduler_value(scheduler, targets, epochs)
  59. def test_quadratic_warmup_scheduler_convert_iterbased(self):
  60. epochs = 10
  61. end = 5
  62. epoch_length = 11
  63. iters = end * epoch_length
  64. warmup_factor = [pow((i + 1) / float(iters), 2) for i in range(iters)]
  65. single_targets = [x * 0.05 for x in warmup_factor] + [0.05] * (
  66. epochs * epoch_length - iters)
  67. targets = [single_targets, [x * epochs for x in single_targets]]
  68. scheduler = QuadraticWarmupParamScheduler.build_iter_from_epoch(
  69. self.optimizer,
  70. param_name='lr',
  71. end=end,
  72. epoch_length=epoch_length)
  73. self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
  74. def test_quadratic_warmup_lr(self):
  75. epochs = 10
  76. iters = 5
  77. warmup_factor = [pow((i + 1) / float(iters), 2) for i in range(iters)]
  78. single_targets = [x * 0.05 for x in warmup_factor] + [0.05] * (
  79. epochs - iters)
  80. targets = [single_targets, [x * epochs for x in single_targets]]
  81. scheduler = QuadraticWarmupLR(self.optimizer, end=iters)
  82. self._test_scheduler_value(scheduler, targets, epochs)
  83. def test_quadratic_warmup_momentum(self):
  84. epochs = 10
  85. iters = 5
  86. warmup_factor = [pow((i + 1) / float(iters), 2) for i in range(iters)]
  87. single_targets = [x * 0.01 for x in warmup_factor] + [0.01] * (
  88. epochs - iters)
  89. targets = [single_targets, [x * epochs for x in single_targets]]
  90. scheduler = QuadraticWarmupMomentum(self.optimizer, end=iters)
  91. self._test_scheduler_value(
  92. scheduler, targets, epochs, param_name='momentum')