# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init
from torch.nn import init

from mmdet.registry import MODELS

def _make_stack_3x3_convs(num_convs,
                          in_channels,
                          out_channels,
                          act_cfg=dict(type='ReLU', inplace=True)):
    """Stack ``num_convs`` 3x3 convs, each followed by an activation."""
    convs = []
    for _ in range(num_convs):
        convs.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
        convs.append(MODELS.build(act_cfg))
        in_channels = out_channels
    return nn.Sequential(*convs)

class InstanceBranch(nn.Module):
    """Instance branch: predicts instance activation maps (IAMs) and, from
    the aggregated instance features, class logits, mask kernels and
    objectness scores."""

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_masks = num_masks
        self.num_classes = num_classes
        self.inst_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        # iam prediction, a simple conv
        self.iam_conv = nn.Conv2d(dim, num_masks, 3, padding=1)
        # outputs
        self.cls_score = nn.Linear(dim, self.num_classes)
        self.mask_kernel = nn.Linear(dim, kernel_dim)
        self.objectness = nn.Linear(dim, 1)
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        # bias the iam_conv and cls_score outputs toward the prior
        # probability, the standard focal-loss initialization trick
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)
        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()
        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
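

# A minimal, self-contained sketch (not part of the original model) of the
# IAM aggregation in ``InstanceBranch.forward`` above: each activation map
# is normalized over the spatial dimension and used as attention weights to
# pool one C-dim feature vector per instance. All shapes are illustrative.
def _demo_iam_aggregation():
    feats = torch.randn(2, 256, 16, 16)    # BxCxHxW instance features
    iam_prob = torch.rand(2, 100, 16, 16)  # BxNxHxW, post-sigmoid maps
    B, N = iam_prob.shape[:2]
    C = feats.size(1)
    iam_prob = iam_prob.view(B, N, -1)
    iam_prob = iam_prob / iam_prob.sum(-1, keepdim=True).clamp(min=1e-6)
    inst = torch.bmm(iam_prob, feats.view(B, C, -1).permute(0, 2, 1))
    assert inst.shape == (2, 100, 256)     # one feature vector per map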

class MaskBranch(nn.Module):
    """Mask branch: produces the shared mask feature map that the
    per-instance kernels are applied to."""

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.mask_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)
        self.projection = nn.Conv2d(dim, kernel_dim, kernel_size=1)
        self._init_weights()

    def _init_weights(self):
        for m in self.mask_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        kaiming_init(self.projection)

    def forward(self, features):
        # mask features (x4 convs)
        features = self.mask_convs(features)
        return self.projection(features)

@MODELS.register_module()
class BaseIAMDecoder(nn.Module):
    """IAM-based SparseInst decoder with a plain (ungrouped) instance
    branch."""

    def __init__(self,
                 in_channels,
                 num_classes,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        # add 2 for the (x, y) coordinate features concatenated in forward()
        in_channels = in_channels + 2
        self.scale_factor = scale_factor
        self.output_iam = output_iam
        self.inst_branch = InstanceBranch(
            in_channels,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.mask_branch = MaskBranch(
            in_channels,
            dim=mask_dim,
            num_convs=mask_conv,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)

    @torch.no_grad()
    def compute_coordinates_linspace(self, x):
        # linspace is not supported in ONNX; kept for reference only
        h, w = x.size(2), x.size(3)
        y_loc = torch.linspace(-1, 1, h, device=x.device)
        x_loc = torch.linspace(-1, 1, w, device=x.device)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc)
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    @torch.no_grad()
    def compute_coordinates(self, x):
        # ONNX-friendly equivalent: normalized (x, y) coords in [-1, 1]
        h, w = x.size(2), x.size(3)
        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / (h - 1)
        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / (w - 1)
        y_loc, x_loc = torch.meshgrid(y_loc, x_loc)
        y_loc = y_loc.expand([x.shape[0], 1, -1, -1])
        x_loc = x_loc.expand([x.shape[0], 1, -1, -1])
        locations = torch.cat([x_loc, y_loc], 1)
        return locations.to(x)

    def forward(self, features):
        coord_features = self.compute_coordinates(features)
        features = torch.cat([coord_features, features], dim=1)
        pred_logits, pred_kernel, pred_scores, iam = self.inst_branch(
            features)
        mask_features = self.mask_branch(features)
        N = pred_kernel.shape[1]
        # mask_features: BxCxHxW
        B, C, H, W = mask_features.shape
        # combine the shared mask features with the per-instance kernels
        pred_masks = torch.bmm(pred_kernel,
                               mask_features.view(B, C,
                                                  H * W)).view(B, N, H, W)
        pred_masks = F.interpolate(
            pred_masks,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)
        output = {
            'pred_logits': pred_logits,
            'pred_masks': pred_masks,
            'pred_scores': pred_scores,
        }
        if self.output_iam:
            iam = F.interpolate(
                iam,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=False)
            output['pred_iam'] = iam
        return output
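

# Usage sketch for the decoder above (not part of the original file). It
# assumes an mmdet/mmengine environment where the 'ReLU' activation is
# registered in MODELS; channel counts and spatial sizes are illustrative.
def _demo_base_decoder():
    decoder = BaseIAMDecoder(in_channels=256, num_classes=80)
    out = decoder(torch.randn(2, 256, 32, 32))
    assert out['pred_logits'].shape == (2, 100, 80)
    assert out['pred_scores'].shape == (2, 100, 1)
    # masks are upsampled by scale_factor=2 relative to the input features
    assert out['pred_masks'].shape == (2, 100, 64, 64)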

class GroupInstanceBranch(nn.Module):
    """Instance branch with a grouped IAM conv: ``num_groups`` activation
    maps are predicted per instance and their pooled features are
    concatenated before the output heads."""

    def __init__(self,
                 in_channels,
                 num_groups=4,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_groups = num_groups
        self.num_classes = num_classes
        self.inst_convs = _make_stack_3x3_convs(
            num_convs, in_channels, dim, act_cfg=act_cfg)
        # iam prediction, a group conv
        expand_dim = dim * self.num_groups
        self.iam_conv = nn.Conv2d(
            dim,
            num_masks * self.num_groups,
            3,
            padding=1,
            groups=self.num_groups)
        # outputs
        self.fc = nn.Linear(expand_dim, expand_dim)
        self.cls_score = nn.Linear(expand_dim, self.num_classes)
        self.mask_kernel = nn.Linear(expand_dim, kernel_dim)
        self.objectness = nn.Linear(expand_dim, 1)
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)
        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)
        caffe2_xavier_init(self.fc)

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()
        B, N = iam_prob.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW)
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        # concatenate the features of the num_groups maps per instance
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)
        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
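

# A small sketch (not part of the model) of the regrouping above: the
# grouped IAM conv emits num_groups blocks of num_masks maps each, so the
# pooled features go from (B, G*num_masks, C) to (B, num_masks, G*C),
# concatenating the G group features that share the same instance slot.
def _demo_group_regroup(num_groups=4, num_masks=100, dim=256):
    B = 2
    N = num_masks * num_groups
    inst = torch.randn(B, N, dim)
    regrouped = inst.reshape(B, num_groups, N // num_groups,
                             -1).transpose(1, 2).reshape(
                                 B, N // num_groups, -1)
    assert regrouped.shape == (B, num_masks, num_groups * dim)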

@MODELS.register_module()
class GroupIAMDecoder(BaseIAMDecoder):
    """IAM decoder whose instance branch uses grouped activation maps."""

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        # replace the plain instance branch built by the parent class;
        # +2 for the coordinate channels, mirroring BaseIAMDecoder.__init__
        self.inst_branch = GroupInstanceBranch(
            in_channels + 2,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)

class GroupInstanceSoftBranch(GroupInstanceBranch):
    """Grouped instance branch that normalizes the IAMs with a spatial
    softmax (plus a learned scalar bias) instead of sigmoid + sum."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.softmax_bias = nn.Parameter(torch.ones([1]))

    def forward(self, features):
        # instance features (x4 convs)
        features = self.inst_convs(features)
        # predict instance activation maps
        iam = self.iam_conv(features)
        B, N = iam.shape[:2]
        C = features.size(1)
        # BxNxHxW -> BxNx(HW), normalized with a biased spatial softmax
        iam_prob = F.softmax(iam.view(B, N, -1) + self.softmax_bias, dim=-1)
        # aggregate features: BxCxHxW -> Bx(HW)xC
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))
        # concatenate the features of the num_groups maps per instance
        inst_features = inst_features.reshape(B, self.num_groups,
                                              N // self.num_groups,
                                              -1).transpose(1, 2).reshape(
                                                  B, N // self.num_groups, -1)
        inst_features = F.relu_(self.fc(inst_features))
        # predict classification & segmentation kernel & objectness
        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
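

# A sketch (not part of the model) contrasting the two IAM normalizations:
# the hard branches apply sigmoid and divide by the spatial sum, while the
# soft branch applies a spatial softmax with a learned scalar bias. Both
# turn each activation map into a distribution over pixels.
def _demo_iam_normalizations():
    iam = torch.randn(2, 100, 16 * 16)
    bias = torch.zeros(1)
    hard = iam.sigmoid()
    hard = hard / hard.sum(-1, keepdim=True).clamp(min=1e-6)
    soft = F.softmax(iam + bias, dim=-1)
    # both variants sum to 1 over the spatial dimension
    assert torch.allclose(hard.sum(-1), torch.ones(2, 100))
    assert torch.allclose(soft.sum(-1), torch.ones(2, 100))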

@MODELS.register_module()
class GroupIAMSoftDecoder(BaseIAMDecoder):
    """Group IAM decoder with softmax-normalized activation maps."""

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels=in_channels,
            num_classes=num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        # +2 for the coordinate channels, mirroring BaseIAMDecoder.__init__
        self.inst_branch = GroupInstanceSoftBranch(
            in_channels + 2,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
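

# A minimal smoke-test sketch (not part of the original file): builds each
# registered decoder variant and checks the output shapes. Assumes an
# mmdet/mmengine environment where the 'ReLU' activation is registered in
# MODELS; the channel count and spatial size below are illustrative only.
if __name__ == '__main__':
    dummy = torch.randn(2, 256, 32, 32)
    for decoder_cls in (BaseIAMDecoder, GroupIAMDecoder, GroupIAMSoftDecoder):
        decoder = decoder_cls(in_channels=256, num_classes=80)
        out = decoder(dummy)
        assert out['pred_logits'].shape == (2, 100, 80)
        assert out['pred_scores'].shape == (2, 100, 1)
        # masks are upsampled by scale_factor=2 relative to the input
        assert out['pred_masks'].shape == (2, 100, 64, 64)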