utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import math
  2. from typing import Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn.bricks import Swish, build_norm_layer
  6. from torch.nn import functional as F
  7. from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import OptConfigType
  10. def variance_scaling_trunc(tensor, gain=1.):
  11. fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
  12. gain /= max(1.0, fan_in)
  13. std = math.sqrt(gain) / .87962566103423978
  14. return trunc_normal_(tensor, 0., std)
  15. @MODELS.register_module()
  16. class Conv2dSamePadding(nn.Conv2d):
  17. def __init__(self,
  18. in_channels: int,
  19. out_channels: int,
  20. kernel_size: Union[int, Tuple[int, int]],
  21. stride: Union[int, Tuple[int, int]] = 1,
  22. padding: Union[int, Tuple[int, int]] = 0,
  23. dilation: Union[int, Tuple[int, int]] = 1,
  24. groups: int = 1,
  25. bias: bool = True):
  26. super().__init__(in_channels, out_channels, kernel_size, stride, 0,
  27. dilation, groups, bias)
  28. def forward(self, x: torch.Tensor) -> torch.Tensor:
  29. img_h, img_w = x.size()[-2:]
  30. kernel_h, kernel_w = self.weight.size()[-2:]
  31. extra_w = (math.ceil(img_w / self.stride[1]) -
  32. 1) * self.stride[1] - img_w + kernel_w
  33. extra_h = (math.ceil(img_h / self.stride[0]) -
  34. 1) * self.stride[0] - img_h + kernel_h
  35. left = extra_w // 2
  36. right = extra_w - left
  37. top = extra_h // 2
  38. bottom = extra_h - top
  39. x = F.pad(x, [left, right, top, bottom])
  40. return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
  41. self.dilation, self.groups)
  42. class MaxPool2dSamePadding(nn.Module):
  43. def __init__(self,
  44. kernel_size: Union[int, Tuple[int, int]] = 3,
  45. stride: Union[int, Tuple[int, int]] = 2,
  46. **kwargs):
  47. super().__init__()
  48. self.pool = nn.MaxPool2d(kernel_size, stride, **kwargs)
  49. self.stride = self.pool.stride
  50. self.kernel_size = self.pool.kernel_size
  51. if isinstance(self.stride, int):
  52. self.stride = [self.stride] * 2
  53. if isinstance(self.kernel_size, int):
  54. self.kernel_size = [self.kernel_size] * 2
  55. def forward(self, x):
  56. h, w = x.shape[-2:]
  57. extra_h = (math.ceil(w / self.stride[1]) -
  58. 1) * self.stride[1] - w + self.kernel_size[1]
  59. extra_v = (math.ceil(h / self.stride[0]) -
  60. 1) * self.stride[0] - h + self.kernel_size[0]
  61. left = extra_h // 2
  62. right = extra_h - left
  63. top = extra_v // 2
  64. bottom = extra_v - top
  65. x = F.pad(x, [left, right, top, bottom])
  66. x = self.pool(x)
  67. return x
  68. class DepthWiseConvBlock(nn.Module):
  69. def __init__(
  70. self,
  71. in_channels: int,
  72. out_channels: int,
  73. apply_norm: bool = True,
  74. conv_bn_act_pattern: bool = False,
  75. norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
  76. ) -> None:
  77. super(DepthWiseConvBlock, self).__init__()
  78. self.depthwise_conv = Conv2dSamePadding(
  79. in_channels,
  80. in_channels,
  81. kernel_size=3,
  82. stride=1,
  83. groups=in_channels,
  84. bias=False)
  85. self.pointwise_conv = Conv2dSamePadding(
  86. in_channels, out_channels, kernel_size=1, stride=1)
  87. self.apply_norm = apply_norm
  88. if self.apply_norm:
  89. self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
  90. self.apply_activation = conv_bn_act_pattern
  91. if self.apply_activation:
  92. self.swish = Swish()
  93. def forward(self, x):
  94. x = self.depthwise_conv(x)
  95. x = self.pointwise_conv(x)
  96. if self.apply_norm:
  97. x = self.bn(x)
  98. if self.apply_activation:
  99. x = self.swish(x)
  100. return x
  101. class DownChannelBlock(nn.Module):
  102. def __init__(
  103. self,
  104. in_channels: int,
  105. out_channels: int,
  106. apply_norm: bool = True,
  107. conv_bn_act_pattern: bool = False,
  108. norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
  109. ) -> None:
  110. super(DownChannelBlock, self).__init__()
  111. self.down_conv = Conv2dSamePadding(in_channels, out_channels, 1)
  112. self.apply_norm = apply_norm
  113. if self.apply_norm:
  114. self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
  115. self.apply_activation = conv_bn_act_pattern
  116. if self.apply_activation:
  117. self.swish = Swish()
  118. def forward(self, x):
  119. x = self.down_conv(x)
  120. if self.apply_norm:
  121. x = self.bn(x)
  122. if self.apply_activation:
  123. x = self.swish(x)
  124. return x