123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as cp
- from mmcv.cnn import build_conv_layer, build_norm_layer
- from mmengine.model import BaseModule
- from torch.nn.modules.utils import _pair
- from mmdet.models.backbones.resnet import Bottleneck, ResNet
- from mmdet.registry import MODELS
class TridentConv(BaseModule):
    """Trident Convolution Module.

    One convolution weight (and optional bias) is shared by all branches.
    Branches differ only in the dilation used when the kernel is applied;
    the padding of each branch is set equal to its dilation.

    Args:
        in_channels (int): Number of channels in input.
        out_channels (int): Number of channels in output.
        kernel_size (int): Size of convolution kernel.
        stride (int, optional): Convolution stride. Default: 1.
        trident_dilations (tuple[int, int, int], optional): Dilations of
            different trident branch. Default: (1, 2, 3).
        test_branch_idx (int, optional): In inference, all 3 branches will
            be used if `test_branch_idx==-1`, otherwise only branch with
            index `test_branch_idx` will be used. Default: 1.
        bias (bool, optional): Whether to use bias in convolution or not.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 trident_dilations=(1, 2, 3),
                 test_branch_idx=1,
                 bias=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.num_branch = len(trident_dilations)
        self.with_bias = bias
        self.test_branch_idx = test_branch_idx
        self.stride = _pair(stride)
        self.kernel_size = _pair(kernel_size)
        # Per-branch padding mirrors the per-branch dilation.
        self.paddings = _pair(trident_dilations)
        self.dilations = trident_dilations
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Storage is allocated without being filled here.
        # NOTE(review): values are presumably set later through weight
        # initialization (init_cfg) or checkpoint loading - confirm.
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

    def extra_repr(self):
        """Return the one-line summary of the layer configuration."""
        # Report only whether a bias exists; interpolating the parameter
        # itself would embed the whole tensor in the module repr.
        fields = (
            f'in_channels={self.in_channels}',
            f'out_channels={self.out_channels}',
            f'kernel_size={self.kernel_size}',
            f'num_branch={self.num_branch}',
            f'test_branch_idx={self.test_branch_idx}',
            f'stride={self.stride}',
            f'paddings={self.paddings}',
            f'dilations={self.dilations}',
            f'bias={self.bias is not None}',
        )
        return ', '.join(fields)

    def forward(self, inputs):
        """Apply the shared kernel to each branch input.

        Args:
            inputs (Sequence[Tensor]): One feature map per branch. When the
                module is in eval mode and ``test_branch_idx != -1`` it
                must contain exactly one element.

        Returns:
            list[Tensor]: One output per processed branch.
        """
        if self.training or self.test_branch_idx == -1:
            # Branch i uses dilations[i] / paddings[i] on inputs[i].
            return [
                F.conv2d(feat, self.weight, self.bias, self.stride, padding,
                         dilation) for feat, dilation, padding in zip(
                             inputs, self.dilations, self.paddings)
            ]

        # Single-branch inference: only the selected dilation is evaluated.
        assert len(inputs) == 1
        branch = self.test_branch_idx
        return [
            F.conv2d(inputs[0], self.weight, self.bias, self.stride,
                     self.paddings[branch], self.dilations[branch])
        ]
# Since TridentNet is defined over ResNet50 and ResNet101, only the
# bottleneck block (TridentBottleneck) is supported here.
class TridentBottleneck(Bottleneck):
    """BottleBlock for TridentResNet.

    Args:
        trident_dilations (tuple[int, int, int]): Dilations of different
            trident branch.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        concat_output (bool): Whether to concat the output list to a Tensor.
            `True` only in the last Block.
        **kwargs: Forwarded unchanged to the ``Bottleneck`` constructor.
    """

    def __init__(self, trident_dilations, test_branch_idx, concat_output,
                 **kwargs):
        super().__init__(**kwargs)
        self.trident_dilations = trident_dilations
        self.num_branch = len(trident_dilations)
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx
        # Assigned after the parent constructor, so this overrides any
        # ``conv2`` the parent set up with the weight-sharing TridentConv.
        self.conv2 = TridentConv(
            self.planes,
            self.planes,
            kernel_size=3,
            stride=self.conv2_stride,
            bias=False,
            trident_dilations=self.trident_dilations,
            test_branch_idx=test_branch_idx,
            init_cfg=dict(
                type='Kaiming',
                distribution='uniform',
                mode='fan_in',
                override=dict(name='conv2')))

    def forward(self, x):
        """Run the block on one tensor or on a list of branch tensors.

        Args:
            x (Tensor | list[Tensor]): A single tensor is replicated once
                per active branch; a list is taken as one tensor per
                branch.

        Returns:
            Tensor | list[Tensor]: The branch outputs concatenated along
            dim 0 when ``concat_output`` is set, otherwise a list with one
            tensor per branch.
        """

        def _inner_forward(x):
            # All branches are active in training or when
            # test_branch_idx == -1; otherwise exactly one.
            num_branch = (
                self.num_branch
                if self.training or self.test_branch_idx == -1 else 1)
            identity = x
            if not isinstance(x, list):
                x = (x, ) * num_branch
                identity = x
                if self.downsample is not None:
                    identity = [self.downsample(b) for b in x]

            out = [self.conv1(b) for b in x]
            out = [self.norm1(b) for b in out]
            out = [self.relu(b) for b in out]

            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv1_plugin_names)
                    for b in out
                ]

            # TridentConv consumes and returns the whole list of branches.
            out = self.conv2(out)
            out = [self.norm2(b) for b in out]
            out = [self.relu(b) for b in out]
            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv2_plugin_names)
                    for b in out
                ]

            out = [self.conv3(b) for b in out]
            out = [self.norm3(b) for b in out]

            if self.with_plugins:
                out = [
                    self.forward_plugin(b, self.after_conv3_plugin_names)
                    for b in out
                ]

            out = [
                out_b + identity_b for out_b, identity_b in zip(out, identity)
            ]
            return out

        def _needs_grad(inp):
            # Every block after the first receives a list of branch
            # tensors, and a list has no ``requires_grad`` attribute.
            if isinstance(inp, list):
                return any(b.requires_grad for b in inp)
            return inp.requires_grad

        if self.with_cp and _needs_grad(x):
            # Pass each branch as its own tensor argument and return a
            # tuple of tensors, so checkpointing sees every input and
            # output tensor directly instead of an opaque list.
            if isinstance(x, list):
                out = cp.checkpoint(
                    lambda *branches: tuple(_inner_forward(list(branches))),
                    *x)
            else:
                out = cp.checkpoint(
                    lambda feat: tuple(_inner_forward(feat)), x)
        else:
            out = _inner_forward(x)

        out = [self.relu(b) for b in out]
        if self.concat_output:
            out = torch.cat(out, dim=0)
        return out
def make_trident_res_layer(block,
                           inplanes,
                           planes,
                           num_blocks,
                           stride=1,
                           trident_dilations=(1, 2, 3),
                           style='pytorch',
                           with_cp=False,
                           conv_cfg=None,
                           norm_cfg=dict(type='BN'),
                           dcn=None,
                           plugins=None,
                           test_branch_idx=-1):
    """Build Trident Res Layers.

    Stacks ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.
    Only the first block receives ``stride`` and the optional shortcut
    projection; only the last block is built with ``concat_output=True``.

    Args:
        block (type): Block class exposing an ``expansion`` attribute and
            accepting the keyword arguments passed below.
        inplanes (int): Channel count entering the first block.
        planes (int): Base channel count of the blocks; the stage outputs
            ``planes * block.expansion`` channels.
        num_blocks (int): Number of blocks to stack.
        stride (int): Stride of the first block. Default: 1.
        trident_dilations (tuple[int]): Dilations of the trident branches.
            Default: (1, 2, 3).
        style (str): Forwarded to every block. Default: 'pytorch'.
        with_cp (bool): Forwarded to every block. Default: False.
        conv_cfg (dict, optional): Config for the shortcut convolution,
            also forwarded to every block. Default: None.
        norm_cfg (dict): Config for the shortcut norm layer, also
            forwarded to every block. Default: dict(type='BN').
        dcn (dict, optional): Forwarded to every block. Default: None.
        plugins (list, optional): Forwarded to every block. Default: None.
        test_branch_idx (int): Forwarded to every block. Default: -1.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    out_channels = planes * block.expansion

    # A 1x1 conv + norm projection is needed on the shortcut whenever the
    # first block changes the resolution or the channel count.
    shortcut = None
    if stride != 1 or inplanes != out_channels:
        shortcut = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, out_channels)[1])

    stage = []
    final_idx = num_blocks - 1
    for idx in range(num_blocks):
        leading = idx == 0
        stage.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride if leading else 1,
                trident_dilations=trident_dilations,
                downsample=shortcut if leading else None,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=plugins,
                test_branch_idx=test_branch_idx,
                concat_output=idx == final_idx))
        # Every block after the first consumes the expanded channel count.
        inplanes = out_channels
    return nn.Sequential(*stage)
@MODELS.register_module()
class TridentResNet(ResNet):
    r"""The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
    ResNet, while in stage 3, Trident BottleBlock is utilized to replace the
    normal BottleBlock to yield trident output. Different branch shares the
    convolution weight but uses different dilations to achieve multi-scale
    output.

                               / stage3(b0) \
    x - stem - stage1 - stage2 - stage3(b1) - output
                               \ stage3(b2) /

    Args:
        depth (int): Depth of resnet, from {50, 101, 152}.
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        trident_dilations (tuple[int]): Dilations of different trident branch.
            len(trident_dilations) should be equal to num_branch.
        **kwargs: Forwarded to the ``ResNet`` constructor. The resulting
            ``num_stages`` must be 3.
    """  # noqa

    def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
                 **kwargs):
        assert num_branch == len(trident_dilations)
        assert depth in (50, 101, 152)
        super().__init__(depth, **kwargs)
        assert self.num_stages == 3
        self.test_branch_idx = test_branch_idx
        self.num_branch = num_branch

        # Only the last stage is rebuilt with trident blocks.
        last_stage_idx = self.num_stages - 1
        stride = self.strides[last_stage_idx]
        dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
        if self.plugins is not None:
            stage_plugins = self.make_stage_plugins(self.plugins,
                                                    last_stage_idx)
        else:
            stage_plugins = None

        planes = self.base_channels * 2**last_stage_idx
        # Input channels of the last stage are the expanded output channels
        # of the stage before it.
        inplanes = (
            self.block.expansion * self.base_channels *
            2**(last_stage_idx - 1))
        res_layer = make_trident_res_layer(
            TridentBottleneck,
            inplanes=inplanes,
            planes=planes,
            num_blocks=self.stage_blocks[last_stage_idx],
            stride=stride,
            trident_dilations=trident_dilations,
            style=self.style,
            with_cp=self.with_cp,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=dcn,
            plugins=stage_plugins,
            test_branch_idx=self.test_branch_idx)

        # Register the trident stage under the last stage's layer name and
        # keep ``res_layers`` pointing at it.
        layer_name = f'layer{last_stage_idx + 1}'
        setattr(self, layer_name, res_layer)
        self.res_layers[last_stage_idx] = layer_name

        # Run stage freezing again now that the last stage is a new module.
        self._freeze_stages()
|