123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmengine.model import BaseModule
- from mmdet.registry import MODELS
- class Transition(BaseModule):
- """Base class for transition.
- Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- """
- def __init__(self, in_channels, out_channels, init_cfg=None):
- super().__init__(init_cfg)
- self.in_channels = in_channels
- self.out_channels = out_channels
- def forward(x):
- pass
- class UpInterpolationConv(Transition):
- """A transition used for up-sampling.
- Up-sample the input by interpolation then refines the feature by
- a convolution layer.
- Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- scale_factor (int): Up-sampling factor. Default: 2.
- mode (int): Interpolation mode. Default: nearest.
- align_corners (bool): Whether align corners when interpolation.
- Default: None.
- kernel_size (int): Kernel size for the conv. Default: 3.
- """
- def __init__(self,
- in_channels,
- out_channels,
- scale_factor=2,
- mode='nearest',
- align_corners=None,
- kernel_size=3,
- init_cfg=None,
- **kwargs):
- super().__init__(in_channels, out_channels, init_cfg)
- self.mode = mode
- self.scale_factor = scale_factor
- self.align_corners = align_corners
- self.conv = ConvModule(
- in_channels,
- out_channels,
- kernel_size,
- padding=(kernel_size - 1) // 2,
- **kwargs)
- def forward(self, x):
- x = F.interpolate(
- x,
- scale_factor=self.scale_factor,
- mode=self.mode,
- align_corners=self.align_corners)
- x = self.conv(x)
- return x
- class LastConv(Transition):
- """A transition used for refining the output of the last stage.
- Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- num_inputs (int): Number of inputs of the FPN features.
- kernel_size (int): Kernel size for the conv. Default: 3.
- """
- def __init__(self,
- in_channels,
- out_channels,
- num_inputs,
- kernel_size=3,
- init_cfg=None,
- **kwargs):
- super().__init__(in_channels, out_channels, init_cfg)
- self.num_inputs = num_inputs
- self.conv_out = ConvModule(
- in_channels,
- out_channels,
- kernel_size,
- padding=(kernel_size - 1) // 2,
- **kwargs)
- def forward(self, inputs):
- assert len(inputs) == self.num_inputs
- return self.conv_out(inputs[-1])
- @MODELS.register_module()
- class FPG(BaseModule):
- """FPG.
- Implementation of `Feature Pyramid Grids (FPG)
- <https://arxiv.org/abs/2004.03580>`_.
- This implementation only gives the basic structure stated in the paper.
- But users can implement different type of transitions to fully explore the
- the potential power of the structure of FPG.
- Args:
- in_channels (int): Number of input channels (feature maps of all levels
- should have the same channels).
- out_channels (int): Number of output channels (used at each scale)
- num_outs (int): Number of output scales.
- stack_times (int): The number of times the pyramid architecture will
- be stacked.
- paths (list[str]): Specify the path order of each stack level.
- Each element in the list should be either 'bu' (bottom-up) or
- 'td' (top-down).
- inter_channels (int): Number of inter channels.
- same_up_trans (dict): Transition that goes down at the same stage.
- same_down_trans (dict): Transition that goes up at the same stage.
- across_lateral_trans (dict): Across-pathway same-stage
- across_down_trans (dict): Across-pathway bottom-up connection.
- across_up_trans (dict): Across-pathway top-down connection.
- across_skip_trans (dict): Across-pathway skip connection.
- output_trans (dict): Transition that trans the output of the
- last stage.
- start_level (int): Index of the start input backbone level used to
- build the feature pyramid. Default: 0.
- end_level (int): Index of the end input backbone level (exclusive) to
- build the feature pyramid. Default: -1, which means the last level.
- add_extra_convs (bool): It decides whether to add conv
- layers on top of the original feature maps. Default to False.
- If True, its actual mode is specified by `extra_convs_on_inputs`.
- norm_cfg (dict): Config dict for normalization layer. Default: None.
- init_cfg (dict or list[dict], optional): Initialization config dict.
- """
- transition_types = {
- 'conv': ConvModule,
- 'interpolation_conv': UpInterpolationConv,
- 'last_conv': LastConv,
- }
- def __init__(self,
- in_channels,
- out_channels,
- num_outs,
- stack_times,
- paths,
- inter_channels=None,
- same_down_trans=None,
- same_up_trans=dict(
- type='conv', kernel_size=3, stride=2, padding=1),
- across_lateral_trans=dict(type='conv', kernel_size=1),
- across_down_trans=dict(type='conv', kernel_size=3),
- across_up_trans=None,
- across_skip_trans=dict(type='identity'),
- output_trans=dict(type='last_conv', kernel_size=3),
- start_level=0,
- end_level=-1,
- add_extra_convs=False,
- norm_cfg=None,
- skip_inds=None,
- init_cfg=[
- dict(type='Caffe2Xavier', layer='Conv2d'),
- dict(
- type='Constant',
- layer=[
- '_BatchNorm', '_InstanceNorm', 'GroupNorm',
- 'LayerNorm'
- ],
- val=1.0)
- ]):
- super(FPG, self).__init__(init_cfg)
- assert isinstance(in_channels, list)
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.num_ins = len(in_channels)
- self.num_outs = num_outs
- if inter_channels is None:
- self.inter_channels = [out_channels for _ in range(num_outs)]
- elif isinstance(inter_channels, int):
- self.inter_channels = [inter_channels for _ in range(num_outs)]
- else:
- assert isinstance(inter_channels, list)
- assert len(inter_channels) == num_outs
- self.inter_channels = inter_channels
- self.stack_times = stack_times
- self.paths = paths
- assert isinstance(paths, list) and len(paths) == stack_times
- for d in paths:
- assert d in ('bu', 'td')
- self.same_down_trans = same_down_trans
- self.same_up_trans = same_up_trans
- self.across_lateral_trans = across_lateral_trans
- self.across_down_trans = across_down_trans
- self.across_up_trans = across_up_trans
- self.output_trans = output_trans
- self.across_skip_trans = across_skip_trans
- self.with_bias = norm_cfg is None
- # skip inds must be specified if across skip trans is not None
- if self.across_skip_trans is not None:
- skip_inds is not None
- self.skip_inds = skip_inds
- assert len(self.skip_inds[0]) <= self.stack_times
- if end_level == -1 or end_level == self.num_ins - 1:
- self.backbone_end_level = self.num_ins
- assert num_outs >= self.num_ins - start_level
- else:
- # if end_level is not the last level, no extra level is allowed
- self.backbone_end_level = end_level + 1
- assert end_level < self.num_ins
- assert num_outs == end_level - start_level + 1
- self.start_level = start_level
- self.end_level = end_level
- self.add_extra_convs = add_extra_convs
- # build lateral 1x1 convs to reduce channels
- self.lateral_convs = nn.ModuleList()
- for i in range(self.start_level, self.backbone_end_level):
- l_conv = nn.Conv2d(self.in_channels[i],
- self.inter_channels[i - self.start_level], 1)
- self.lateral_convs.append(l_conv)
- extra_levels = num_outs - self.backbone_end_level + self.start_level
- self.extra_downsamples = nn.ModuleList()
- for i in range(extra_levels):
- if self.add_extra_convs:
- fpn_idx = self.backbone_end_level - self.start_level + i
- extra_conv = nn.Conv2d(
- self.inter_channels[fpn_idx - 1],
- self.inter_channels[fpn_idx],
- 3,
- stride=2,
- padding=1)
- self.extra_downsamples.append(extra_conv)
- else:
- self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))
- self.fpn_transitions = nn.ModuleList() # stack times
- for s in range(self.stack_times):
- stage_trans = nn.ModuleList() # num of feature levels
- for i in range(self.num_outs):
- # same, across_lateral, across_down, across_up
- trans = nn.ModuleDict()
- if s in self.skip_inds[i]:
- stage_trans.append(trans)
- continue
- # build same-stage down trans (used in bottom-up paths)
- if i == 0 or self.same_up_trans is None:
- same_up_trans = None
- else:
- same_up_trans = self.build_trans(
- self.same_up_trans, self.inter_channels[i - 1],
- self.inter_channels[i])
- trans['same_up'] = same_up_trans
- # build same-stage up trans (used in top-down paths)
- if i == self.num_outs - 1 or self.same_down_trans is None:
- same_down_trans = None
- else:
- same_down_trans = self.build_trans(
- self.same_down_trans, self.inter_channels[i + 1],
- self.inter_channels[i])
- trans['same_down'] = same_down_trans
- # build across lateral trans
- across_lateral_trans = self.build_trans(
- self.across_lateral_trans, self.inter_channels[i],
- self.inter_channels[i])
- trans['across_lateral'] = across_lateral_trans
- # build across down trans
- if i == self.num_outs - 1 or self.across_down_trans is None:
- across_down_trans = None
- else:
- across_down_trans = self.build_trans(
- self.across_down_trans, self.inter_channels[i + 1],
- self.inter_channels[i])
- trans['across_down'] = across_down_trans
- # build across up trans
- if i == 0 or self.across_up_trans is None:
- across_up_trans = None
- else:
- across_up_trans = self.build_trans(
- self.across_up_trans, self.inter_channels[i - 1],
- self.inter_channels[i])
- trans['across_up'] = across_up_trans
- if self.across_skip_trans is None:
- across_skip_trans = None
- else:
- across_skip_trans = self.build_trans(
- self.across_skip_trans, self.inter_channels[i - 1],
- self.inter_channels[i])
- trans['across_skip'] = across_skip_trans
- # build across_skip trans
- stage_trans.append(trans)
- self.fpn_transitions.append(stage_trans)
- self.output_transition = nn.ModuleList() # output levels
- for i in range(self.num_outs):
- trans = self.build_trans(
- self.output_trans,
- self.inter_channels[i],
- self.out_channels,
- num_inputs=self.stack_times + 1)
- self.output_transition.append(trans)
- self.relu = nn.ReLU(inplace=True)
- def build_trans(self, cfg, in_channels, out_channels, **extra_args):
- cfg_ = cfg.copy()
- trans_type = cfg_.pop('type')
- trans_cls = self.transition_types[trans_type]
- return trans_cls(in_channels, out_channels, **cfg_, **extra_args)
- def fuse(self, fuse_dict):
- out = None
- for item in fuse_dict.values():
- if item is not None:
- if out is None:
- out = item
- else:
- out = out + item
- return out
- def forward(self, inputs):
- assert len(inputs) == len(self.in_channels)
- # build all levels from original feature maps
- feats = [
- lateral_conv(inputs[i + self.start_level])
- for i, lateral_conv in enumerate(self.lateral_convs)
- ]
- for downsample in self.extra_downsamples:
- feats.append(downsample(feats[-1]))
- outs = [feats]
- for i in range(self.stack_times):
- current_outs = outs[-1]
- next_outs = []
- direction = self.paths[i]
- for j in range(self.num_outs):
- if i in self.skip_inds[j]:
- next_outs.append(outs[-1][j])
- continue
- # feature level
- if direction == 'td':
- lvl = self.num_outs - j - 1
- else:
- lvl = j
- # get transitions
- if direction == 'td':
- same_trans = self.fpn_transitions[i][lvl]['same_down']
- else:
- same_trans = self.fpn_transitions[i][lvl]['same_up']
- across_lateral_trans = self.fpn_transitions[i][lvl][
- 'across_lateral']
- across_down_trans = self.fpn_transitions[i][lvl]['across_down']
- across_up_trans = self.fpn_transitions[i][lvl]['across_up']
- across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
- # init output
- to_fuse = dict(
- same=None, lateral=None, across_up=None, across_down=None)
- # same downsample/upsample
- if same_trans is not None:
- to_fuse['same'] = same_trans(next_outs[-1])
- # across lateral
- if across_lateral_trans is not None:
- to_fuse['lateral'] = across_lateral_trans(
- current_outs[lvl])
- # across downsample
- if lvl > 0 and across_up_trans is not None:
- to_fuse['across_up'] = across_up_trans(current_outs[lvl -
- 1])
- # across upsample
- if (lvl < self.num_outs - 1 and across_down_trans is not None):
- to_fuse['across_down'] = across_down_trans(
- current_outs[lvl + 1])
- if across_skip_trans is not None:
- to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
- x = self.fuse(to_fuse)
- next_outs.append(x)
- if direction == 'td':
- outs.append(next_outs[::-1])
- else:
- outs.append(next_outs)
- # output trans
- final_outs = []
- for i in range(self.num_outs):
- lvl_out_list = []
- for s in range(len(outs)):
- lvl_out_list.append(outs[s][i])
- lvl_out = self.output_transition[i](lvl_out_list)
- final_outs.append(lvl_out)
- return final_outs
|