import math
from typing import Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn.bricks import Swish, build_norm_layer
from torch.nn import functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType


def variance_scaling_trunc(tensor: torch.Tensor,
                           gain: float = 1.) -> torch.Tensor:
    """Variance-scaling (fan-in) initialization with a truncated normal.

    0.87962566103423978 is the standard deviation of a standard normal
    truncated to [-2, 2]; dividing by it compensates for the variance the
    truncation removes, so the weights keep a std close to
    ``sqrt(gain / fan_in)``.
    """
    fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
    gain /= max(1.0, fan_in)
    std = math.sqrt(gain) / .87962566103423978
    return trunc_normal_(tensor, 0., std)
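

def _demo_variance_scaling_trunc() -> None:
    # Usage sketch, not part of the original module: apply the fan-in
    # variance-scaling init to a conv weight. The layer sizes below are
    # arbitrary assumptions chosen for illustration.
    conv = nn.Conv2d(64, 64, kernel_size=3)
    variance_scaling_trunc(conv.weight, gain=1.0)
    # fan_in = 64 * 3 * 3 = 576, so the requested std is roughly
    # sqrt(1 / 576) / 0.8796 ~= 0.047.
    assert abs(conv.weight.std().item() - 0.047) < 0.01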


@MODELS.register_module()
class Conv2dSamePadding(nn.Conv2d):
    """Conv2d with TensorFlow-style 'SAME' padding.

    The padding is computed dynamically in ``forward`` from the input size,
    so the constructor's ``padding`` argument is ignored and the underlying
    ``nn.Conv2d`` is built with ``padding=0``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        # Effective kernel extent once dilation is taken into account.
        kernel_h = (kernel_h - 1) * self.dilation[0] + 1
        kernel_w = (kernel_w - 1) * self.dilation[1] + 1
        # Pad so the output size is ceil(input / stride), matching
        # TensorFlow's 'SAME' scheme; clamp at zero so a stride larger than
        # the kernel never produces a negative (cropping) pad.
        extra_w = max(
            (math.ceil(img_w / self.stride[1]) - 1) * self.stride[1] -
            img_w + kernel_w, 0)
        extra_h = max(
            (math.ceil(img_h / self.stride[0]) - 1) * self.stride[0] -
            img_h + kernel_h, 0)

        left = extra_w // 2
        right = extra_w - left
        top = extra_h // 2
        bottom = extra_h - top
        # F.pad takes (left, right, top, bottom) for the last two dims.
        x = F.pad(x, [left, right, top, bottom])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
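

def _demo_conv_same_padding() -> None:
    # Minimal sketch (shapes are illustrative assumptions): with 'SAME'
    # padding the output spatial size is ceil(input / stride), even for odd
    # input sizes that a fixed pad would truncate.
    conv = Conv2dSamePadding(3, 8, kernel_size=3, stride=2, bias=False)
    out = conv(torch.randn(1, 3, 223, 223))
    assert out.shape[-2:] == (112, 112)  # ceil(223 / 2) on both sides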


class MaxPool2dSamePadding(nn.Module):
    """MaxPool2d with TensorFlow-style 'SAME' padding.

    Mirrors :class:`Conv2dSamePadding`: the input is padded in ``forward``
    so that the pooled output has size ``ceil(input / stride)``.
    """

    def __init__(self,
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 2,
                 **kwargs):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size, stride, **kwargs)
        # Normalize int arguments to (h, w) pairs for the padding maths.
        self.stride = self.pool.stride
        self.kernel_size = self.pool.kernel_size
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_h, img_w = x.shape[-2:]
        extra_w = max(
            (math.ceil(img_w / self.stride[1]) - 1) * self.stride[1] -
            img_w + self.kernel_size[1], 0)
        extra_h = max(
            (math.ceil(img_h / self.stride[0]) - 1) * self.stride[0] -
            img_h + self.kernel_size[0], 0)

        left = extra_w // 2
        right = extra_w - left
        top = extra_h // 2
        bottom = extra_h - top
        x = F.pad(x, [left, right, top, bottom])
        return self.pool(x)
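

def _demo_maxpool_same_padding() -> None:
    # Sketch with assumed input sizes: each spatial side pools down to
    # ceil(side / stride), as with TensorFlow 'SAME' pooling.
    pool = MaxPool2dSamePadding(kernel_size=3, stride=2)
    out = pool(torch.randn(1, 8, 57, 64))
    assert out.shape[-2:] == (29, 32)  # ceil(57 / 2), ceil(64 / 2)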


class DepthWiseConvBlock(nn.Module):
    """Depthwise separable conv: a 3x3 depthwise conv followed by a 1x1
    pointwise conv, with optional BN and Swish (conv-bn-act pattern)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        apply_norm: bool = True,
        conv_bn_act_pattern: bool = False,
        norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
    ) -> None:
        super().__init__()
        self.depthwise_conv = Conv2dSamePadding(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            groups=in_channels,
            bias=False)
        self.pointwise_conv = Conv2dSamePadding(
            in_channels, out_channels, kernel_size=1, stride=1)

        self.apply_norm = apply_norm
        if self.apply_norm:
            self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]

        self.apply_activation = conv_bn_act_pattern
        if self.apply_activation:
            self.swish = Swish()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        if self.apply_norm:
            x = self.bn(x)
        if self.apply_activation:
            x = self.swish(x)
        return x
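

def _demo_depthwise_block() -> None:
    # Sketch (channel counts are illustrative assumptions): the block keeps
    # the spatial size (stride 1, 'SAME' padding) while remapping channels.
    block = DepthWiseConvBlock(16, 24, conv_bn_act_pattern=True)
    out = block(torch.randn(2, 16, 32, 32))
    assert out.shape == (2, 24, 32, 32)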


class DownChannelBlock(nn.Module):
    """1x1 conv that changes the channel count, with the same optional
    BN/Swish pattern as :class:`DepthWiseConvBlock`."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        apply_norm: bool = True,
        conv_bn_act_pattern: bool = False,
        norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
    ) -> None:
        super().__init__()
        self.down_conv = Conv2dSamePadding(in_channels, out_channels, 1)
        self.apply_norm = apply_norm
        if self.apply_norm:
            self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
        self.apply_activation = conv_bn_act_pattern
        if self.apply_activation:
            self.swish = Swish()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.down_conv(x)
        if self.apply_norm:
            x = self.bn(x)
        if self.apply_activation:
            x = self.swish(x)
        return x
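

if __name__ == '__main__':
    # Illustrative smoke test, not part of the original module: run the demo
    # sketches above, plus the 1x1 down-channel block.
    _demo_variance_scaling_trunc()
    _demo_conv_same_padding()
    _demo_maxpool_same_padding()
    _demo_depthwise_block()
    block = DownChannelBlock(64, 32)
    assert block(torch.randn(1, 64, 16, 16)).shape == (1, 32, 16, 16)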