test_ema.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import itertools
  3. import math
  4. from unittest import TestCase
  5. import torch
  6. import torch.nn as nn
  7. from mmengine.testing import assert_allclose
  8. from mmdet.models.layers import ExpMomentumEMA
  9. class TestEMA(TestCase):
  10. def test_exp_momentum_ema(self):
  11. model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10))
  12. # Test invalid gamma
  13. with self.assertRaisesRegex(AssertionError,
  14. 'gamma must be greater than 0'):
  15. ExpMomentumEMA(model, gamma=-1)
  16. # Test EMA
  17. model = torch.nn.Sequential(
  18. torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
  19. momentum = 0.1
  20. gamma = 4
  21. ema_model = ExpMomentumEMA(model, momentum=momentum, gamma=gamma)
  22. averaged_params = [
  23. torch.zeros_like(param) for param in model.parameters()
  24. ]
  25. n_updates = 10
  26. for i in range(n_updates):
  27. updated_averaged_params = []
  28. for p, p_avg in zip(model.parameters(), averaged_params):
  29. p.detach().add_(torch.randn_like(p))
  30. if i == 0:
  31. updated_averaged_params.append(p.clone())
  32. else:
  33. m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum
  34. updated_averaged_params.append(
  35. (p_avg * (1 - m) + p * m).clone())
  36. ema_model.update_parameters(model)
  37. averaged_params = updated_averaged_params
  38. for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
  39. assert_allclose(p_target, p_ema)
  40. def test_exp_momentum_ema_update_buffer(self):
  41. model = nn.Sequential(
  42. nn.Conv2d(1, 5, kernel_size=3), nn.BatchNorm2d(5, momentum=0.3),
  43. nn.Linear(5, 10))
  44. # Test invalid gamma
  45. with self.assertRaisesRegex(AssertionError,
  46. 'gamma must be greater than 0'):
  47. ExpMomentumEMA(model, gamma=-1)
  48. # Test EMA with momentum annealing.
  49. momentum = 0.1
  50. gamma = 4
  51. ema_model = ExpMomentumEMA(
  52. model, gamma=gamma, momentum=momentum, update_buffers=True)
  53. averaged_params = [
  54. torch.zeros_like(param)
  55. for param in itertools.chain(model.parameters(), model.buffers())
  56. if param.size() != torch.Size([])
  57. ]
  58. n_updates = 10
  59. for i in range(n_updates):
  60. updated_averaged_params = []
  61. params = [
  62. param for param in itertools.chain(model.parameters(),
  63. model.buffers())
  64. if param.size() != torch.Size([])
  65. ]
  66. for p, p_avg in zip(params, averaged_params):
  67. p.detach().add_(torch.randn_like(p))
  68. if i == 0:
  69. updated_averaged_params.append(p.clone())
  70. else:
  71. m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum
  72. updated_averaged_params.append(
  73. (p_avg * (1 - m) + p * m).clone())
  74. ema_model.update_parameters(model)
  75. averaged_params = updated_averaged_params
  76. ema_params = [
  77. param for param in itertools.chain(ema_model.module.parameters(),
  78. ema_model.module.buffers())
  79. if param.size() != torch.Size([])
  80. ]
  81. for p_target, p_ema in zip(averaged_params, ema_params):
  82. assert_allclose(p_target, p_ema)