123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, is_norm
from mmengine.model import caffe2_xavier_init, constant_init, normal_init
from torch.nn import BatchNorm2d

from mmdet.registry import MODELS
class Bottleneck(nn.Module):
    """Residual bottleneck block for the DilatedEncoder in `YOLOF
    <https://arxiv.org/abs/2103.09460>`.

    The block stacks three ConvModules (1x1 reduce, 3x3 dilated, 1x1
    expand) and adds the input back through a residual connection.

    Args:
        in_channels (int): The number of input channels.
        mid_channels (int): The number of middle output channels.
        dilation (int): Dilation rate.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 dilation,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        # 1x1 projection down to the bottleneck width.
        self.conv1 = ConvModule(
            in_channels, mid_channels, 1, norm_cfg=norm_cfg)
        # 3x3 dilated conv; padding == dilation keeps the spatial size.
        self.conv2 = ConvModule(
            mid_channels,
            mid_channels,
            3,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg)
        # 1x1 projection back up to the input width.
        self.conv3 = ConvModule(
            mid_channels, in_channels, 1, norm_cfg=norm_cfg)

    def forward(self, x):
        """Run the three convs and add the residual shortcut."""
        shortcut = x
        y = self.conv3(self.conv2(self.conv1(x)))
        return y + shortcut
@MODELS.register_module()
class DilatedEncoder(nn.Module):
    """Dilated Encoder for `YOLOF <https://arxiv.org/abs/2103.09460>`.

    This module contains two types of components:
    - the original FPN lateral convolution layer and fpn convolution layer,
      which are 1x1 conv + 3x3 conv
    - the dilated residual block

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        block_mid_channels (int): The number of middle block output channels.
        num_residual_blocks (int): The number of residual blocks.
        block_dilations (list): The list of residual blocks dilation.
    """

    def __init__(self, in_channels, out_channels, block_mid_channels,
                 num_residual_blocks, block_dilations):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block_mid_channels = block_mid_channels
        self.num_residual_blocks = num_residual_blocks
        self.block_dilations = block_dilations
        self._init_layers()

    def _init_layers(self):
        """Build the lateral/fpn convs and the dilated residual stack."""
        # 1x1 lateral conv + BN projecting the input to out_channels.
        self.lateral_conv = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size=1)
        self.lateral_norm = BatchNorm2d(self.out_channels)
        # 3x3 conv + BN smoothing the projected feature map.
        self.fpn_conv = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=3, padding=1)
        self.fpn_norm = BatchNorm2d(self.out_channels)
        # One dilated bottleneck per configured dilation rate.
        self.dilated_encoder_blocks = nn.Sequential(*[
            Bottleneck(
                self.out_channels,
                self.block_mid_channels,
                dilation=self.block_dilations[idx])
            for idx in range(self.num_residual_blocks)
        ])

    def init_weights(self):
        """Initialize convs (xavier / normal) and norms (constant 1)."""
        caffe2_xavier_init(self.lateral_conv)
        caffe2_xavier_init(self.fpn_conv)
        for norm in (self.lateral_norm, self.fpn_norm):
            constant_init(norm, 1)
        for module in self.dilated_encoder_blocks.modules():
            if isinstance(module, nn.Conv2d):
                normal_init(module, mean=0, std=0.01)
            if is_norm(module):
                constant_init(module, 1)

    def forward(self, feature):
        """Encode the last backbone feature map and return it as a 1-tuple."""
        x = self.lateral_norm(self.lateral_conv(feature[-1]))
        x = self.fpn_norm(self.fpn_conv(x))
        # Trailing comma is intentional: necks return a tuple of features.
        return self.dilated_encoder_blocks(x),
|