- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Optional, Tuple
- import mmcv
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.models.utils.misc import floordiv
- from mmdet.registry import MODELS
- from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
- from ..layers import mask_matrix_nms
- from ..utils import center_of_mass, generate_coordinate, multi_apply
- from .base_mask_head import BaseMaskHead
- @MODELS.register_module()
- class SOLOHead(BaseMaskHead):
- """SOLO mask head used in `SOLO: Segmenting Objects by Locations.
- <https://arxiv.org/abs/1912.04488>`_
- Args:
- num_classes (int): Number of categories excluding the background
- category.
- in_channels (int): Number of channels in the input feature map.
- feat_channels (int): Number of hidden channels of the stacked convs.
- Defaults to 256.
- stacked_convs (int): Number of stacking convs of the head.
- Defaults to 4.
- strides (tuple): Downsample factor of each feature map.
- scale_ranges (tuple[tuple[int, int]]): Scale range of multi-level
- masks, in the format [(min1, max1), (min2, max2), ...]. A GT
- instance is assigned to a level when the square root of its
- bbox area falls inside that level's range.
- pos_scale (float): Constant scale factor to control the center region.
- num_grids (list[int]): Each level divides the image into a uniform
- num_grid x num_grid grid, so the number of mask output channels
- is num_grid ** 2. Defaults to [40, 36, 24, 16, 12].
- cls_down_index (int): The index of downsample operation in
- classification branch. Defaults to 0.
- loss_mask (dict): Config of mask loss.
- loss_cls (dict): Config of classification loss.
- norm_cfg (dict): Dictionary to construct and config norm layer.
- Defaults to norm_cfg=dict(type='GN', num_groups=32,
- requires_grad=True).
- train_cfg (dict): Training config of head.
- test_cfg (dict): Testing config of head.
- init_cfg (dict or list[dict], optional): Initialization config dict.
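- Example:
- >>> # A minimal illustrative sketch (not a tested configuration);
- >>> # assumes mmdet's default registries so the losses can be built.
- >>> import torch
- >>> self = SOLOHead(num_classes=4, in_channels=32, feat_channels=32)
- >>> feats = tuple(
- ...     torch.rand(1, 32, 256 // s, 256 // s) for s in self.strides)
- >>> mask_preds, cls_preds = self.forward(feats)
- >>> len(mask_preds) == len(cls_preds) == self.num_levels
- True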
- """
- def __init__(
- self,
- num_classes: int,
- in_channels: int,
- feat_channels: int = 256,
- stacked_convs: int = 4,
- strides: tuple = (4, 8, 16, 32, 64),
- scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256), (128,
- 512)),
- pos_scale: float = 0.2,
- num_grids: list = [40, 36, 24, 16, 12],
- cls_down_index: int = 0,
- loss_mask: ConfigType = dict(
- type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
- loss_cls: ConfigType = dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- norm_cfg: ConfigType = dict(
- type='GN', num_groups=32, requires_grad=True),
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- init_cfg: MultiConfig = [
- dict(type='Normal', layer='Conv2d', std=0.01),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_mask_list')),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_cls'))
- ]
- ) -> None:
- super().__init__(init_cfg=init_cfg)
- self.num_classes = num_classes
- self.cls_out_channels = self.num_classes
- self.in_channels = in_channels
- self.feat_channels = feat_channels
- self.stacked_convs = stacked_convs
- self.strides = strides
- self.num_grids = num_grids
- # number of FPN feats
- self.num_levels = len(strides)
- assert self.num_levels == len(scale_ranges) == len(num_grids)
- self.scale_ranges = scale_ranges
- self.pos_scale = pos_scale
- self.cls_down_index = cls_down_index
- self.loss_cls = MODELS.build(loss_cls)
- self.loss_mask = MODELS.build(loss_mask)
- self.norm_cfg = norm_cfg
- self.init_cfg = init_cfg
- self.train_cfg = train_cfg
- self.test_cfg = test_cfg
- self._init_layers()
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- self.mask_convs = nn.ModuleList()
- self.cls_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = self.in_channels + 2 if i == 0 else self.feat_channels
- self.mask_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=self.norm_cfg))
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=self.norm_cfg))
- self.conv_mask_list = nn.ModuleList()
- for num_grid in self.num_grids:
- self.conv_mask_list.append(
- nn.Conv2d(self.feat_channels, num_grid**2, 1))
- self.conv_cls = nn.Conv2d(
- self.feat_channels, self.cls_out_channels, 3, padding=1)
- def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
- """Downsample the first feat and upsample last feat in feats.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- list[Tensor]: Features after resizing, each is a 4D-tensor.
- """
- out = []
- for i in range(len(x)):
- if i == 0:
- out.append(
- F.interpolate(x[0], scale_factor=0.5, mode='bilinear'))
- elif i == len(x) - 1:
- out.append(
- F.interpolate(
- x[i], size=x[i - 1].shape[-2:], mode='bilinear'))
- else:
- out.append(x[i])
- return out
- def forward(self, x: Tuple[Tensor]) -> tuple:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- tuple: A tuple of classification scores and mask prediction.
- - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
- Each element in the list has shape
- (batch_size, num_grids**2, h, w).
- - mlvl_cls_preds (list[Tensor]): Multi-level scores.
- Each element in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- """
- assert len(x) == self.num_levels
- feats = self.resize_feats(x)
- mlvl_mask_preds = []
- mlvl_cls_preds = []
- for i in range(self.num_levels):
- x = feats[i]
- mask_feat = x
- cls_feat = x
- # Generate normalized coordinate maps (CoordConv-style) and
- # concatenate them to the mask features.
- coord_feat = generate_coordinate(mask_feat.size(),
- mask_feat.device)
- mask_feat = torch.cat([mask_feat, coord_feat], 1)
- for mask_layer in self.mask_convs:
- mask_feat = mask_layer(mask_feat)
- mask_feat = F.interpolate(
- mask_feat, scale_factor=2, mode='bilinear')
- mask_preds = self.conv_mask_list[i](mask_feat)
- # cls branch
- for j, cls_layer in enumerate(self.cls_convs):
- if j == self.cls_down_index:
- num_grid = self.num_grids[i]
- cls_feat = F.interpolate(
- cls_feat, size=num_grid, mode='bilinear')
- cls_feat = cls_layer(cls_feat)
- cls_pred = self.conv_cls(cls_feat)
- if not self.training:
- feat_wh = feats[0].size()[-2:]
- upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
- mask_preds = F.interpolate(
- mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
- cls_pred = cls_pred.sigmoid()
- # Point NMS: a 2x2 max-pool keeps only cells that are the local
- # maximum of their neighbourhood, suppressing duplicate peaks.
- local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
- keep_mask = local_max[:, :, :-1, :-1] == cls_pred
- cls_pred = cls_pred * keep_mask
- mlvl_mask_preds.append(mask_preds)
- mlvl_cls_preds.append(cls_pred)
- return mlvl_mask_preds, mlvl_cls_preds
- def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
- mlvl_cls_preds: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict], **kwargs) -> dict:
- """Calculate the loss based on the features extracted by the mask head.
- Args:
- mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
- Each element in the list has shape
- (batch_size, num_grids**2, h, w).
- mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
- in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes``, ``masks``,
- and ``labels`` attributes.
- batch_img_metas (list[dict]): Meta information of multiple images.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- num_levels = self.num_levels
- num_imgs = len(batch_img_metas)
- featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]
- # Each `BoolTensor` in `pos_masks` represents whether the
- # corresponding grid cell is positive.
- pos_mask_targets, labels, pos_masks = multi_apply(
- self._get_targets_single,
- batch_gt_instances,
- featmap_sizes=featmap_sizes)
- # Re-arrange: the outer list currently indexes images; build lists
- # whose outer index is the feature level instead.
- mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
- mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
- mlvl_pos_masks = [[] for _ in range(num_levels)]
- mlvl_labels = [[] for _ in range(num_levels)]
- for img_id in range(num_imgs):
- assert num_levels == len(pos_mask_targets[img_id])
- for lvl in range(num_levels):
- mlvl_pos_mask_targets[lvl].append(
- pos_mask_targets[img_id][lvl])
- mlvl_pos_mask_preds[lvl].append(
- mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
- mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
- mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
- # Concatenate over images within each level.
- temp_mlvl_cls_preds = []
- for lvl in range(num_levels):
- mlvl_pos_mask_targets[lvl] = torch.cat(
- mlvl_pos_mask_targets[lvl], dim=0)
- mlvl_pos_mask_preds[lvl] = torch.cat(
- mlvl_pos_mask_preds[lvl], dim=0)
- mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
- mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
- temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
- 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
- num_pos = sum(item.sum() for item in mlvl_pos_masks)
- # Dice loss over positive cells, summed and normalized by the total
- # number of positives.
- loss_mask = []
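- # A level with no positive cells contributes `pred.sum()` (an empty
- # sum, i.e. zero) so that the mask branch still receives a gradient.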
- for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
- if pred.size()[0] == 0:
- loss_mask.append(pred.sum().unsqueeze(0))
- continue
- loss_mask.append(
- self.loss_mask(pred, target, reduction_override='none'))
- if num_pos > 0:
- loss_mask = torch.cat(loss_mask).sum() / num_pos
- else:
- loss_mask = torch.cat(loss_mask).mean()
- flatten_labels = torch.cat(mlvl_labels)
- flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
- loss_cls = self.loss_cls(
- flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
- return dict(loss_mask=loss_mask, loss_cls=loss_cls)
- def _get_targets_single(self,
- gt_instances: InstanceData,
- featmap_sizes: Optional[list] = None) -> tuple:
- """Compute targets for predictions of single image.
- Args:
- gt_instances (:obj:`InstanceData`): Ground truth of instance
- annotations. It should include ``bboxes``, ``labels``,
- and ``masks`` attributes.
- featmap_sizes (list[:obj:`torch.size`]): Size of each
- feature map from feature pyramid, each element
- means (feat_h, feat_w). Defaults to None.
- Returns:
- Tuple: Usually returns a tuple containing targets for predictions.
- - mlvl_pos_mask_targets (list[Tensor]): Each element represents
- the binary mask targets for positive points in this
- level, has shape (num_pos, out_h, out_w).
- - mlvl_labels (list[Tensor]): Each element is
- classification labels for all
- points in this level, has shape
- (num_grid, num_grid).
- - mlvl_pos_masks (list[Tensor]): Each element is
- a `BoolTensor` to represent whether the
- corresponding point in a single level
- is positive, has shape (num_grid ** 2,).
- """
- gt_labels = gt_instances.labels
- device = gt_labels.device
- gt_bboxes = gt_instances.bboxes
- gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
- (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
- gt_masks = gt_instances.masks.to_tensor(
- dtype=torch.bool, device=device)
- mlvl_pos_mask_targets = []
- mlvl_labels = []
- mlvl_pos_masks = []
- for (lower_bound, upper_bound), stride, featmap_size, num_grid \
- in zip(self.scale_ranges, self.strides,
- featmap_sizes, self.num_grids):
- mask_target = torch.zeros(
- [num_grid**2, featmap_size[0], featmap_size[1]],
- dtype=torch.uint8,
- device=device)
- # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
- labels = torch.zeros([num_grid, num_grid],
- dtype=torch.int64,
- device=device) + self.num_classes
- pos_mask = torch.zeros([num_grid**2],
- dtype=torch.bool,
- device=device)
- gt_inds = ((gt_areas >= lower_bound) &
- (gt_areas <= upper_bound)).nonzero().flatten()
- if len(gt_inds) == 0:
- mlvl_pos_mask_targets.append(
- mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
- mlvl_labels.append(labels)
- mlvl_pos_masks.append(pos_mask)
- continue
- hit_gt_bboxes = gt_bboxes[gt_inds]
- hit_gt_labels = gt_labels[gt_inds]
- hit_gt_masks = gt_masks[gt_inds, ...]
- pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
- hit_gt_bboxes[:, 0]) * self.pos_scale
- pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
- hit_gt_bboxes[:, 1]) * self.pos_scale
- # Keep only GT masks that contain at least one foreground pixel.
- valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
- output_stride = stride / 2
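- # Each GT activates only grid cells inside its "center region": the
- # GT box shrunk by `pos_scale`, further clipped below to the 3x3
- # cells around the cell containing the mask's center of mass.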
- for gt_mask, gt_label, pos_h_range, pos_w_range, \
- valid_mask_flag in \
- zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
- pos_w_ranges, valid_mask_flags):
- if not valid_mask_flag:
- continue
- upsampled_size = (featmap_sizes[0][0] * 4,
- featmap_sizes[0][1] * 4)
- center_h, center_w = center_of_mass(gt_mask)
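- # Map the mass center from image coordinates to grid coordinates,
- # e.g. a center at 31% of the image width with num_grid=40 falls in
- # grid column int(0.31 * 40) = 12.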
- coord_w = int(
- floordiv((center_w / upsampled_size[1]), (1. / num_grid),
- rounding_mode='trunc'))
- coord_h = int(
- floordiv((center_h / upsampled_size[0]), (1. / num_grid),
- rounding_mode='trunc'))
- # Grid-cell bounds of the center region: top, down, left, right.
- top_box = max(
- 0,
- int(
- floordiv(
- (center_h - pos_h_range) / upsampled_size[0],
- (1. / num_grid),
- rounding_mode='trunc')))
- down_box = min(
- num_grid - 1,
- int(
- floordiv(
- (center_h + pos_h_range) / upsampled_size[0],
- (1. / num_grid),
- rounding_mode='trunc')))
- left_box = max(
- 0,
- int(
- floordiv(
- (center_w - pos_w_range) / upsampled_size[1],
- (1. / num_grid),
- rounding_mode='trunc')))
- right_box = min(
- num_grid - 1,
- int(
- floordiv(
- (center_w + pos_w_range) / upsampled_size[1],
- (1. / num_grid),
- rounding_mode='trunc')))
- top = max(top_box, coord_h - 1)
- down = min(down_box, coord_h + 1)
- left = max(coord_w - 1, left_box)
- right = min(right_box, coord_w + 1)
- labels[top:(down + 1), left:(right + 1)] = gt_label
- # Instance mask targets.
- gt_mask = np.uint8(gt_mask.cpu().numpy())
- # Follow the original implementation: rescale with cv2 via
- # mmcv.imrescale, whose interpolation differs slightly from
- # F.interpolate.
- gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
- gt_mask = torch.from_numpy(gt_mask).to(device=device)
- for i in range(top, down + 1):
- for j in range(left, right + 1):
- index = int(i * num_grid + j)
- mask_target[index, :gt_mask.shape[0], :gt_mask.
- shape[1]] = gt_mask
- pos_mask[index] = True
- mlvl_pos_mask_targets.append(mask_target[pos_mask])
- mlvl_labels.append(labels)
- mlvl_pos_masks.append(pos_mask)
- return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
- def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
- mlvl_cls_scores: List[Tensor],
- batch_img_metas: List[dict], **kwargs) -> InstanceList:
- """Transform a batch of output features extracted from the head into
- mask results.
- Args:
- mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
- Each element in the list has shape
- (batch_size, num_grids**2, h, w).
- mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
- in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- batch_img_metas (list[dict]): Meta information of all images.
- Returns:
- list[:obj:`InstanceData`]: Processed results of multiple
- images. Each :obj:`InstanceData` usually contains the
- following keys.
- - scores (Tensor): Classification scores, has shape
- (num_instance,).
- - labels (Tensor): Has shape (num_instances,).
- - masks (Tensor): Processed mask results, has
- shape (num_instances, h, w).
- """
- mlvl_cls_scores = [
- item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
- ]
- assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
- num_levels = len(mlvl_cls_scores)
- results_list = []
- for img_id in range(len(batch_img_metas)):
- cls_pred_list = [
- mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
- for lvl in range(num_levels)
- ]
- mask_pred_list = [
- mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
- ]
- cls_pred_list = torch.cat(cls_pred_list, dim=0)
- mask_pred_list = torch.cat(mask_pred_list, dim=0)
- img_meta = batch_img_metas[img_id]
- results = self._predict_by_feat_single(
- cls_pred_list, mask_pred_list, img_meta=img_meta)
- results_list.append(results)
- return results_list
- def _predict_by_feat_single(self,
- cls_scores: Tensor,
- mask_preds: Tensor,
- img_meta: dict,
- cfg: OptConfigType = None) -> InstanceData:
- """Transform a single image's features extracted from the head into
- mask results.
- Args:
- cls_scores (Tensor): Classification score of all points
- in single image, has shape (num_points, num_classes).
- mask_preds (Tensor): Mask prediction of all points in
- single image, has shape (num_points, feat_h, feat_w).
- img_meta (dict): Meta information of corresponding image.
- cfg (dict, optional): Config used in test phase.
- Defaults to None.
- Returns:
- :obj:`InstanceData`: Processed results of a single image.
- It usually contains the following keys.
- - scores (Tensor): Classification scores, has shape
- (num_instance,).
- - labels (Tensor): Has shape (num_instances,).
- - masks (Tensor): Processed mask results, has
- shape (num_instances, h, w).
- """
- def empty_results(cls_scores, ori_shape):
- """Generate a empty results."""
- results = InstanceData()
- results.scores = cls_scores.new_ones(0)
- results.masks = cls_scores.new_zeros(0, *ori_shape)
- results.labels = cls_scores.new_ones(0)
- results.bboxes = cls_scores.new_zeros(0, 4)
- return results
- cfg = self.test_cfg if cfg is None else cfg
- assert len(cls_scores) == len(mask_preds)
- featmap_size = mask_preds.size()[-2:]
- h, w = img_meta['img_shape'][:2]
- upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
- score_mask = (cls_scores > cfg.score_thr)
- cls_scores = cls_scores[score_mask]
- if len(cls_scores) == 0:
- return empty_results(cls_scores, img_meta['ori_shape'][:2])
- inds = score_mask.nonzero()
- cls_labels = inds[:, 1]
- # Filter out masks whose area is smaller than the stride of the
- # corresponding feature level. First recover each prediction's
- # level so its stride can be looked up.
- lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
- strides = cls_scores.new_ones(lvl_interval[-1])
- strides[:lvl_interval[0]] *= self.strides[0]
- for lvl in range(1, self.num_levels):
- strides[lvl_interval[lvl -
- 1]:lvl_interval[lvl]] *= self.strides[lvl]
- strides = strides[inds[:, 0]]
- mask_preds = mask_preds[inds[:, 0]]
- masks = mask_preds > cfg.mask_thr
- sum_masks = masks.sum((1, 2)).float()
- keep = sum_masks > strides
- if keep.sum() == 0:
- return empty_results(cls_scores, img_meta['ori_shape'][:2])
- masks = masks[keep]
- mask_preds = mask_preds[keep]
- sum_masks = sum_masks[keep]
- cls_scores = cls_scores[keep]
- cls_labels = cls_labels[keep]
- # Maskness: mean predicted probability inside the binary mask,
- # used to rescale the classification score.
- mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
- cls_scores *= mask_scores
- scores, labels, _, keep_inds = mask_matrix_nms(
- masks,
- cls_labels,
- cls_scores,
- mask_area=sum_masks,
- nms_pre=cfg.nms_pre,
- max_num=cfg.max_per_img,
- kernel=cfg.kernel,
- sigma=cfg.sigma,
- filter_thr=cfg.filter_thr)
- # mask_matrix_nms may return an empty Tensor
- if len(keep_inds) == 0:
- return empty_results(cls_scores, img_meta['ori_shape'][:2])
- mask_preds = mask_preds[keep_inds]
- mask_preds = F.interpolate(
- mask_preds.unsqueeze(0), size=upsampled_size,
- mode='bilinear')[:, :, :h, :w]
- mask_preds = F.interpolate(
- mask_preds, size=img_meta['ori_shape'][:2],
- mode='bilinear').squeeze(0)
- masks = mask_preds > cfg.mask_thr
- results = InstanceData()
- results.masks = masks
- results.labels = labels
- results.scores = scores
- # create an empty bbox in InstanceData to avoid bugs when
- # calculating metrics.
- results.bboxes = results.scores.new_zeros(len(scores), 4)
- return results
- @MODELS.register_module()
- class DecoupledSOLOHead(SOLOHead):
- """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
- <https://arxiv.org/abs/1912.04488>`_
- Args:
- init_cfg (dict or list[dict], optional): Initialization config dict.
- """
- def __init__(self,
- *args,
- init_cfg: MultiConfig = [
- dict(type='Normal', layer='Conv2d', std=0.01),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_mask_list_x')),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_mask_list_y')),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_cls'))
- ],
- **kwargs) -> None:
- super().__init__(*args, init_cfg=init_cfg, **kwargs)
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- self.mask_convs_x = nn.ModuleList()
- self.mask_convs_y = nn.ModuleList()
- self.cls_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = self.in_channels + 1 if i == 0 else self.feat_channels
- self.mask_convs_x.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=self.norm_cfg))
- self.mask_convs_y.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=self.norm_cfg))
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- norm_cfg=self.norm_cfg))
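- # Unlike the coupled head, each branch predicts only `num_grid`
- # channels per level (one map per grid row or column); the mask for
- # cell (i, j) is later formed as the product of y-map i and x-map j.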
- self.conv_mask_list_x = nn.ModuleList()
- self.conv_mask_list_y = nn.ModuleList()
- for num_grid in self.num_grids:
- self.conv_mask_list_x.append(
- nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
- self.conv_mask_list_y.append(
- nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
- self.conv_cls = nn.Conv2d(
- self.feat_channels, self.cls_out_channels, 3, padding=1)
- def forward(self, x: Tuple[Tensor]) -> Tuple:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- tuple: A tuple of classification scores and mask prediction.
- - mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
- from x branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- - mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
- from y branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- - mlvl_cls_preds (list[Tensor]): Multi-level scores.
- Each element in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- """
- assert len(x) == self.num_levels
- feats = self.resize_feats(x)
- mask_preds_x = []
- mask_preds_y = []
- cls_preds = []
- for i in range(self.num_levels):
- x = feats[i]
- mask_feat = x
- cls_feat = x
- # Generate normalized coordinate maps; the x branch receives the
- # x-coordinate channel and the y branch the y-coordinate channel.
- coord_feat = generate_coordinate(mask_feat.size(),
- mask_feat.device)
- mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
- mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)
- for mask_layer_x, mask_layer_y in \
- zip(self.mask_convs_x, self.mask_convs_y):
- mask_feat_x = mask_layer_x(mask_feat_x)
- mask_feat_y = mask_layer_y(mask_feat_y)
- mask_feat_x = F.interpolate(
- mask_feat_x, scale_factor=2, mode='bilinear')
- mask_feat_y = F.interpolate(
- mask_feat_y, scale_factor=2, mode='bilinear')
- mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
- mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)
- # cls branch
- for j, cls_layer in enumerate(self.cls_convs):
- if j == self.cls_down_index:
- num_grid = self.num_grids[i]
- cls_feat = F.interpolate(
- cls_feat, size=num_grid, mode='bilinear')
- cls_feat = cls_layer(cls_feat)
- cls_pred = self.conv_cls(cls_feat)
- if not self.training:
- feat_wh = feats[0].size()[-2:]
- upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
- mask_pred_x = F.interpolate(
- mask_pred_x.sigmoid(),
- size=upsampled_size,
- mode='bilinear')
- mask_pred_y = F.interpolate(
- mask_pred_y.sigmoid(),
- size=upsampled_size,
- mode='bilinear')
- cls_pred = cls_pred.sigmoid()
- # Point NMS: a 2x2 max-pool keeps only cells that are the local
- # maximum of their neighbourhood, suppressing duplicate peaks.
- local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
- keep_mask = local_max[:, :, :-1, :-1] == cls_pred
- cls_pred = cls_pred * keep_mask
- mask_preds_x.append(mask_pred_x)
- mask_preds_y.append(mask_pred_y)
- cls_preds.append(cls_pred)
- return mask_preds_x, mask_preds_y, cls_preds
- def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
- mlvl_mask_preds_y: List[Tensor],
- mlvl_cls_preds: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict], **kwargs) -> dict:
- """Calculate the loss based on the features extracted by the mask head.
- Args:
- mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
- from x branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
- from y branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
- in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes``, ``masks``,
- and ``labels`` attributes.
- batch_img_metas (list[dict]): Meta information of multiple images.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- num_levels = self.num_levels
- num_imgs = len(batch_img_metas)
- featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
- pos_mask_targets, labels, xy_pos_indexes = multi_apply(
- self._get_targets_single,
- batch_gt_instances,
- featmap_sizes=featmap_sizes)
- # Re-arrange: the outer list currently indexes images; build lists
- # whose outer index is the feature level instead.
- mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
- mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
- mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
- mlvl_labels = [[] for _ in range(num_levels)]
- for img_id in range(num_imgs):
- for lvl in range(num_levels):
- mlvl_pos_mask_targets[lvl].append(
- pos_mask_targets[img_id][lvl])
- mlvl_pos_mask_preds_x[lvl].append(
- mlvl_mask_preds_x[lvl][img_id,
- xy_pos_indexes[img_id][lvl][:, 1]])
- mlvl_pos_mask_preds_y[lvl].append(
- mlvl_mask_preds_y[lvl][img_id,
- xy_pos_indexes[img_id][lvl][:, 0]])
- mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
- # Concatenate over images within each level.
- temp_mlvl_cls_preds = []
- for lvl in range(num_levels):
- mlvl_pos_mask_targets[lvl] = torch.cat(
- mlvl_pos_mask_targets[lvl], dim=0)
- mlvl_pos_mask_preds_x[lvl] = torch.cat(
- mlvl_pos_mask_preds_x[lvl], dim=0)
- mlvl_pos_mask_preds_y[lvl] = torch.cat(
- mlvl_pos_mask_preds_y[lvl], dim=0)
- mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
- temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
- 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
- num_pos = 0.
- # Dice loss over positive cells, summed and normalized by num_pos.
- loss_mask = []
- for pred_x, pred_y, target in \
- zip(mlvl_pos_mask_preds_x,
- mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
- num_masks = pred_x.size(0)
- if num_masks == 0:
- # Make sure both branches still receive gradients.
- loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
- continue
- num_pos += num_masks
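- # Decoupled SOLO: the full mask probability for a positive cell is
- # the elementwise product of its sigmoided x- and y-branch maps.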
- pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
- loss_mask.append(
- self.loss_mask(pred_mask, target, reduction_override='none'))
- if num_pos > 0:
- loss_mask = torch.cat(loss_mask).sum() / num_pos
- else:
- loss_mask = torch.cat(loss_mask).mean()
- # Category (classification) loss.
- flatten_labels = torch.cat(mlvl_labels)
- flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
- loss_cls = self.loss_cls(
- flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
- return dict(loss_mask=loss_mask, loss_cls=loss_cls)
- def _get_targets_single(self,
- gt_instances: InstanceData,
- featmap_sizes: Optional[list] = None) -> tuple:
- """Compute targets for predictions of single image.
- Args:
- gt_instances (:obj:`InstanceData`): Ground truth of instance
- annotations. It should include ``bboxes``, ``labels``,
- and ``masks`` attributes.
- featmap_sizes (list[:obj:`torch.size`]): Size of each
- feature map from feature pyramid, each element
- means (feat_h, feat_w). Defaults to None.
- Returns:
- Tuple: Usually returns a tuple containing targets for predictions.
- - mlvl_pos_mask_targets (list[Tensor]): Each element represents
- the binary mask targets for positive points in this
- level, has shape (num_pos, out_h, out_w).
- - mlvl_labels (list[Tensor]): Each element is
- classification labels for all
- points in this level, has shape
- (num_grid, num_grid).
- - mlvl_xy_pos_indexes (list[Tensor]): Each element
- in the list contains the index of positive samples in
- the corresponding level, has shape (num_pos, 2); the last
- dimension holds (index_y, index_x).
- """
- mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \
- super()._get_targets_single(gt_instances,
- featmap_sizes=featmap_sizes)
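- # Background cells carry label `num_classes`; subtracting it zeroes
- # them out, so `nonzero()` yields the (row, col), i.e. (y, x), grid
- # indices of the positive cells.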
- mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
- for item in mlvl_labels]
- return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
- def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
- mlvl_mask_preds_y: List[Tensor],
- mlvl_cls_scores: List[Tensor],
- batch_img_metas: List[dict], **kwargs) -> InstanceList:
- """Transform a batch of output features extracted from the head into
- mask results.
- Args:
- mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
- from x branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
- from y branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
- in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- batch_img_metas (list[dict]): Meta information of all images.
- Returns:
- list[:obj:`InstanceData`]: Processed results of multiple
- images. Each :obj:`InstanceData` usually contains the
- following keys.
- - scores (Tensor): Classification scores, has shape
- (num_instance,).
- - labels (Tensor): Has shape (num_instances,).
- - masks (Tensor): Processed mask results, has
- shape (num_instances, h, w).
- """
- mlvl_cls_scores = [
- item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
- ]
- assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
- num_levels = len(mlvl_cls_scores)
- results_list = []
- for img_id in range(len(batch_img_metas)):
- cls_pred_list = [
- mlvl_cls_scores[i][img_id].view(
- -1, self.cls_out_channels).detach()
- for i in range(num_levels)
- ]
- mask_pred_list_x = [
- mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
- ]
- mask_pred_list_y = [
- mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
- ]
- cls_pred_list = torch.cat(cls_pred_list, dim=0)
- mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
- mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
- img_meta = batch_img_metas[img_id]
- results = self._predict_by_feat_single(
- cls_pred_list,
- mask_pred_list_x,
- mask_pred_list_y,
- img_meta=img_meta)
- results_list.append(results)
- return results_list
- def _predict_by_feat_single(self,
- cls_scores: Tensor,
- mask_preds_x: Tensor,
- mask_preds_y: Tensor,
- img_meta: dict,
- cfg: OptConfigType = None) -> InstanceData:
- """Transform a single image's features extracted from the head into
- mask results.
- Args:
- cls_scores (Tensor): Classification score of all points
- in single image, has shape (num_points, num_classes).
- mask_preds_x (Tensor): Mask prediction of x branch of
- all points in single image, has shape
- (sum_num_grids, feat_h, feat_w).
- mask_preds_y (Tensor): Mask prediction of y branch of
- all points in single image, has shape
- (sum_num_grids, feat_h, feat_w).
- img_meta (dict): Meta information of corresponding image.
- cfg (dict, optional): Config used in test phase.
- Defaults to None.
- Returns:
- :obj:`InstanceData`: Processed results of a single image.
- It usually contains the following keys.
- - scores (Tensor): Classification scores, has shape
- (num_instance,).
- - labels (Tensor): Has shape (num_instances,).
- - masks (Tensor): Processed mask results, has
- shape (num_instances, h, w).
- """
- def empty_results(cls_scores, ori_shape):
- """Generate a empty results."""
- results = InstanceData()
- results.scores = cls_scores.new_ones(0)
- results.masks = cls_scores.new_zeros(0, *ori_shape)
- results.labels = cls_scores.new_ones(0)
- results.bboxes = cls_scores.new_zeros(0, 4)
- return results
- cfg = self.test_cfg if cfg is None else cfg
- featmap_size = mask_preds_x.size()[-2:]
- h, w = img_meta['img_shape'][:2]
- upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
- score_mask = (cls_scores > cfg.score_thr)
- cls_scores = cls_scores[score_mask]
- inds = score_mask.nonzero()
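- # Recover per-prediction bookkeeping from the flattened grid index:
- # its FPN level, grid size and stride, plus row offsets into the
- # concatenated x/y branch outputs (which hold `num_grid` rows per
- # level, not `num_grid**2`).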
- lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
- num_all_points = lvl_interval[-1]
- lvl_start_index = inds.new_ones(num_all_points)
- num_grids = inds.new_ones(num_all_points)
- seg_size = inds.new_tensor(self.num_grids).cumsum(0)
- mask_lvl_start_index = inds.new_ones(num_all_points)
- strides = inds.new_ones(num_all_points)
- lvl_start_index[:lvl_interval[0]] *= 0
- mask_lvl_start_index[:lvl_interval[0]] *= 0
- num_grids[:lvl_interval[0]] *= self.num_grids[0]
- strides[:lvl_interval[0]] *= self.strides[0]
- for lvl in range(1, self.num_levels):
- lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
- lvl_interval[lvl - 1]
- mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
- seg_size[lvl - 1]
- num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
- self.num_grids[lvl]
- strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
- self.strides[lvl]
- lvl_start_index = lvl_start_index[inds[:, 0]]
- mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
- num_grids = num_grids[inds[:, 0]]
- strides = strides[inds[:, 0]]
- y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
- x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
- y_inds = mask_lvl_start_index + y_lvl_offset
- x_inds = mask_lvl_start_index + x_lvl_offset
- cls_labels = inds[:, 1]
- mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]
- masks = mask_preds > cfg.mask_thr
- sum_masks = masks.sum((1, 2)).float()
- keep = sum_masks > strides
- if keep.sum() == 0:
- return empty_results(cls_scores, img_meta['ori_shape'][:2])
- masks = masks[keep]
- mask_preds = mask_preds[keep]
- sum_masks = sum_masks[keep]
- cls_scores = cls_scores[keep]
- cls_labels = cls_labels[keep]
- # Maskness: mean predicted probability inside the binary mask,
- # used to rescale the classification score.
- mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
- cls_scores *= mask_scores
- scores, labels, _, keep_inds = mask_matrix_nms(
- masks,
- cls_labels,
- cls_scores,
- mask_area=sum_masks,
- nms_pre=cfg.nms_pre,
- max_num=cfg.max_per_img,
- kernel=cfg.kernel,
- sigma=cfg.sigma,
- filter_thr=cfg.filter_thr)
- # mask_matrix_nms may return an empty Tensor
- if len(keep_inds) == 0:
- return empty_results(cls_scores, img_meta['ori_shape'][:2])
- mask_preds = mask_preds[keep_inds]
- mask_preds = F.interpolate(
- mask_preds.unsqueeze(0), size=upsampled_size,
- mode='bilinear')[:, :, :h, :w]
- mask_preds = F.interpolate(
- mask_preds, size=img_meta['ori_shape'][:2],
- mode='bilinear').squeeze(0)
- masks = mask_preds > cfg.mask_thr
- results = InstanceData()
- results.masks = masks
- results.labels = labels
- results.scores = scores
- # create an empty bbox in InstanceData to avoid bugs when
- # calculating metrics.
- results.bboxes = results.scores.new_zeros(len(scores), 4)
- return results
- @MODELS.register_module()
- class DecoupledSOLOLightHead(DecoupledSOLOHead):
- """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
- Locations <https://arxiv.org/abs/1912.04488>`_
- Args:
- dcn_cfg (dict, optional): Config of DCN used in the last stacked
- conv of ``mask_convs`` and ``cls_convs``. Defaults to None.
- init_cfg (dict or list[dict], optional): Initialization config dict.
- """
- def __init__(self,
- *args,
- dcn_cfg: OptConfigType = None,
- init_cfg: MultiConfig = [
- dict(type='Normal', layer='Conv2d', std=0.01),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_mask_list_x')),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_mask_list_y')),
- dict(
- type='Normal',
- std=0.01,
- bias_prob=0.01,
- override=dict(name='conv_cls'))
- ],
- **kwargs) -> None:
- assert dcn_cfg is None or isinstance(dcn_cfg, dict)
- self.dcn_cfg = dcn_cfg
- super().__init__(*args, init_cfg=init_cfg, **kwargs)
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- self.mask_convs = nn.ModuleList()
- self.cls_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
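- # When `dcn_cfg` is set, apply DCN only in the last stacked conv of
- # both branches; all earlier convs remain regular convolutions.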
- if self.dcn_cfg is not None \
- and i == self.stacked_convs - 1:
- conv_cfg = self.dcn_cfg
- else:
- conv_cfg = None
- chn = self.in_channels + 2 if i == 0 else self.feat_channels
- self.mask_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=conv_cfg,
- norm_cfg=self.norm_cfg))
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=conv_cfg,
- norm_cfg=self.norm_cfg))
- self.conv_mask_list_x = nn.ModuleList()
- self.conv_mask_list_y = nn.ModuleList()
- for num_grid in self.num_grids:
- self.conv_mask_list_x.append(
- nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
- self.conv_mask_list_y.append(
- nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
- self.conv_cls = nn.Conv2d(
- self.feat_channels, self.cls_out_channels, 3, padding=1)
- def forward(self, x: Tuple[Tensor]) -> Tuple:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- tuple: A tuple of classification scores and mask prediction.
- - mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
- from x branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- - mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
- from y branch. Each element in the list has shape
- (batch_size, num_grids, h, w).
- - mlvl_cls_preds (list[Tensor]): Multi-level scores.
- Each element in the list has shape
- (batch_size, num_classes, num_grids, num_grids).
- """
- assert len(x) == self.num_levels
- feats = self.resize_feats(x)
- mask_preds_x = []
- mask_preds_y = []
- cls_preds = []
- for i in range(self.num_levels):
- x = feats[i]
- mask_feat = x
- cls_feat = x
- # Generate normalized coordinate maps (CoordConv-style) and
- # concatenate them to the mask features.
- coord_feat = generate_coordinate(mask_feat.size(),
- mask_feat.device)
- mask_feat = torch.cat([mask_feat, coord_feat], 1)
- for mask_layer in self.mask_convs:
- mask_feat = mask_layer(mask_feat)
- mask_feat = F.interpolate(
- mask_feat, scale_factor=2, mode='bilinear')
- mask_pred_x = self.conv_mask_list_x[i](mask_feat)
- mask_pred_y = self.conv_mask_list_y[i](mask_feat)
- # cls branch
- for j, cls_layer in enumerate(self.cls_convs):
- if j == self.cls_down_index:
- num_grid = self.num_grids[i]
- cls_feat = F.interpolate(
- cls_feat, size=num_grid, mode='bilinear')
- cls_feat = cls_layer(cls_feat)
- cls_pred = self.conv_cls(cls_feat)
- if not self.training:
- feat_wh = feats[0].size()[-2:]
- upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
- mask_pred_x = F.interpolate(
- mask_pred_x.sigmoid(),
- size=upsampled_size,
- mode='bilinear')
- mask_pred_y = F.interpolate(
- mask_pred_y.sigmoid(),
- size=upsampled_size,
- mode='bilinear')
- cls_pred = cls_pred.sigmoid()
- # Point NMS: a 2x2 max-pool keeps only cells that are the local
- # maximum of their neighbourhood, suppressing duplicate peaks.
- local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
- keep_mask = local_max[:, :, :-1, :-1] == cls_pred
- cls_pred = cls_pred * keep_mask
- mask_preds_x.append(mask_pred_x)
- mask_preds_y.append(mask_pred_y)
- cls_preds.append(cls_pred)
- return mask_preds_x, mask_preds_y, cls_preds