normed_predictor.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import digit_version
from torch import Tensor

from mmdet.registry import MODELS

MODELS.register_module('Linear', module=nn.Linear)
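
# Registering plain nn.Linear under the name 'Linear' lets configs select the
# classification predictor by string. A hypothetical config fragment
# (illustrative only, not part of this file):
#
#     cls_predictor_cfg = dict(type='NormedLinear', tempearture=20)
#
# would swap the standard linear classifier for the normalized one below.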


@MODELS.register_module(name='NormedLinear')
class NormedLinear(nn.Linear):
    """Normalized Linear Layer.

    Args:
        tempearture (float, optional): Temperature term. Defaults to 20.
        power (float, optional): Power term. Defaults to 1.0.
        eps (float, optional): The minimal value of the divisor, used to
            keep numerical stability. Defaults to 1e-6.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.eps = eps
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the weights."""
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedLinear`."""
        # Normalize each weight row and each input sample (norms raised to
        # `power`), then scale the inputs by the temperature, so that with
        # power=1 and zero bias the logits are temperature-scaled cosine
        # similarities.
        weight_ = self.weight / (
            self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.tempearture
        return F.linear(x_, weight_, self.bias)
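
# A minimal usage sketch (illustrative shapes, not from the original file):
# NormedLinear drops in for nn.Linear and takes the same positional
# arguments.
#
#     >>> layer = NormedLinear(256, 80)        # in_features, out_features
#     >>> logits = layer(torch.rand(4, 256))   # -> shape (4, 80)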


@MODELS.register_module(name='NormedConv2d')
class NormedConv2d(nn.Conv2d):
    """Normalized Conv2d Layer.

    Args:
        tempearture (float, optional): Temperature term. Defaults to 20.
        power (float, optional): Power term. Defaults to 1.0.
        eps (float, optional): The minimal value of the divisor, used to
            keep numerical stability. Defaults to 1e-6.
        norm_over_kernel (bool, optional): Whether to normalize each filter
            over its whole kernel rather than over the input-channel
            dimension only. Defaults to False.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 norm_over_kernel: bool = False,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.norm_over_kernel = norm_over_kernel
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedConv2d`."""
        if not self.norm_over_kernel:
            # Normalize the weights across the input-channel dimension at
            # each kernel position.
            weight_ = self.weight / (
                self.weight.norm(dim=1, keepdim=True).pow(self.power) +
                self.eps)
        else:
            # Flatten each output filter, normalize over all of its
            # elements, and broadcast the norm back to the kernel shape.
            weight_ = self.weight / (
                self.weight.view(self.weight.size(0), -1).norm(
                    dim=1, keepdim=True).pow(self.power)[..., None, None] +
                self.eps)
        # Normalize the input per spatial location across channels, then
        # scale by the temperature.
        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.tempearture
        if hasattr(self, 'conv2d_forward'):
            # PyTorch < 1.5 exposes the convolution helper as
            # `conv2d_forward`.
            x_ = self.conv2d_forward(x_, weight_)
        elif digit_version(torch.__version__) >= digit_version('1.8'):
            # From PyTorch 1.8 on, `_conv_forward` also takes the bias.
            # (`digit_version` avoids the lexicographic string comparison,
            # which would misorder e.g. '1.10' and '1.8'.)
            x_ = self._conv_forward(x_, weight_, self.bias)
        else:
            x_ = self._conv_forward(x_, weight_)
        return x_
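

# A minimal smoke test, assuming mmdet is installed; the config values and
# tensor shapes below are illustrative, not part of the original file.
# Because both layers are registered in MODELS, they can also be built from
# config dicts.
if __name__ == '__main__':
    conv = MODELS.build(
        dict(
            type='NormedConv2d',
            in_channels=16,
            out_channels=4,
            kernel_size=3,
            padding=1,
            norm_over_kernel=True))
    out = conv(torch.rand(2, 16, 8, 8))
    assert out.shape == (2, 4, 8, 8)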