- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import Conv2d
- from mmengine.model import caffe2_xavier_init
- from mmengine.structures import InstanceData, PixelData
- from torch import Tensor
- from mmdet.models.layers.pixel_decoder import PixelDecoder
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.structures import SampleList
- from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
- OptMultiConfig, reduce_mean)
- from ..layers import DetrTransformerDecoder, SinePositionalEncoding
- from ..utils import multi_apply, preprocess_panoptic_gt
- from .anchor_free_head import AnchorFreeHead
- @MODELS.register_module()
- class MaskFormerHead(AnchorFreeHead):
- """Implements the MaskFormer head.
- See `Per-Pixel Classification is Not All You Need for Semantic
- Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.
- Args:
- in_channels (list[int]): Number of channels in each input feature
- map.
- feat_channels (int): Number of channels for features.
- out_channels (int): Number of channels for output.
- num_things_classes (int): Number of thing classes.
- num_stuff_classes (int): Number of stuff classes.
- num_queries (int): Number of queries in the transformer decoder.
- pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
- decoder.
- enforce_decoder_input_project (bool): Whether to add a layer
- to change the embed_dim of transformer encoder in pixel decoder to
- the embed_dim of transformer decoder. Defaults to False.
- transformer_decoder (:obj:`ConfigDict` or dict): Config for
- transformer decoder.
- positional_encoding (:obj:`ConfigDict` or dict): Config for
- transformer decoder position encoding.
- loss_cls (:obj:`ConfigDict` or dict): Config of the classification
- loss. Defaults to `CrossEntropyLoss`.
- loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
- Defaults to `FocalLoss`.
- loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
- Defaults to `DiceLoss`.
- train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
- MaskFormer head.
- test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
- MaskFormer head.
- init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
- dict], optional): Initialization config dict. Defaults to None.
- """
- def __init__(self,
- in_channels: List[int],
- feat_channels: int,
- out_channels: int,
- num_things_classes: int = 80,
- num_stuff_classes: int = 53,
- num_queries: int = 100,
- pixel_decoder: ConfigType = ...,
- enforce_decoder_input_project: bool = False,
- transformer_decoder: ConfigType = ...,
- positional_encoding: ConfigType = dict(
- num_feats=128, normalize=True),
- loss_cls: ConfigType = dict(
- type='CrossEntropyLoss',
- use_sigmoid=False,
- loss_weight=1.0,
- class_weight=[1.0] * 133 + [0.1]),
- loss_mask: ConfigType = dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=20.0),
- loss_dice: ConfigType = dict(
- type='DiceLoss',
- use_sigmoid=True,
- activate=True,
- naive_dice=True,
- loss_weight=1.0),
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- init_cfg: OptMultiConfig = None,
- **kwargs) -> None:
- # call BaseModule.__init__ directly, deliberately skipping
- # AnchorFreeHead.__init__ (this head builds its own layers below)
- super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
- self.num_things_classes = num_things_classes
- self.num_stuff_classes = num_stuff_classes
- self.num_classes = self.num_things_classes + self.num_stuff_classes
- self.num_queries = num_queries
- pixel_decoder.update(
- in_channels=in_channels,
- feat_channels=feat_channels,
- out_channels=out_channels)
- self.pixel_decoder = MODELS.build(pixel_decoder)
- self.transformer_decoder = DetrTransformerDecoder(
- **transformer_decoder)
- self.decoder_embed_dims = self.transformer_decoder.embed_dims
- # exact type check (not isinstance): subclasses of PixelDecoder such as
- # TransformerEncoderPixelDecoder already output decoder_embed_dims
- # channels, so only the plain PixelDecoder may need a 1x1 projection
- if type(self.pixel_decoder) == PixelDecoder and (
- self.decoder_embed_dims != in_channels[-1]
- or enforce_decoder_input_project):
- self.decoder_input_proj = Conv2d(
- in_channels[-1], self.decoder_embed_dims, kernel_size=1)
- else:
- self.decoder_input_proj = nn.Identity()
- self.decoder_pe = SinePositionalEncoding(**positional_encoding)
- self.query_embed = nn.Embedding(self.num_queries, out_channels)
- self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
- self.mask_embed = nn.Sequential(
- nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
- nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
- nn.Linear(feat_channels, out_channels))
- self.test_cfg = test_cfg
- self.train_cfg = train_cfg
- if train_cfg:
- self.assigner = TASK_UTILS.build(train_cfg['assigner'])
- self.sampler = TASK_UTILS.build(
- train_cfg['sampler'], default_args=dict(context=self))
- self.class_weight = loss_cls.class_weight
- self.loss_cls = MODELS.build(loss_cls)
- self.loss_mask = MODELS.build(loss_mask)
- self.loss_dice = MODELS.build(loss_dice)
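- # Worked example of the projection rule above (sizes assumed): with
- # in_channels=[256, 512, 1024, 2048] and decoder embed_dims=256, the plain
- # PixelDecoder leaves the last feature map at 2048 channels, so
- # decoder_input_proj becomes Conv2d(2048, 256, kernel_size=1); a pixel
- # decoder that already emits embed_dims channels (e.g. the transformer
- # encoder variant) keeps nn.Identity() instead.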
- def init_weights(self) -> None:
- """Initialize the input projection, pixel decoder and transformer decoder."""
- if isinstance(self.decoder_input_proj, Conv2d):
- caffe2_xavier_init(self.decoder_input_proj, bias=0)
- self.pixel_decoder.init_weights()
- for p in self.transformer_decoder.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
- def preprocess_gt(
- self, batch_gt_instances: InstanceList,
- batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList:
- """Preprocess the ground truth for all images.
- Args:
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``labels``, the ground truth
- labels of each bbox with shape (num_gts, ), and ``masks``,
- the ground truth masks of each instance in an image with
- shape (num_gts, h, w).
- batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth
- of semantic segmentation, each with the shape (1, h, w).
- [0, num_thing_class - 1] means things,
- [num_thing_class, num_class - 1] means stuff,
- 255 means VOID. It is None when training instance segmentation.
- Returns:
- list[:obj:`InstanceData`]: each contains the following keys
- - labels (Tensor): Ground truth class indices\
- for an image, with shape (n, ), where n is the sum of the\
- number of stuff types and the number of instances in the image.
- - masks (Tensor): Ground truth masks for an\
- image, with shape (n, h, w).
- """
- num_things_list = [self.num_things_classes] * len(batch_gt_instances)
- num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances)
- gt_labels_list = [
- gt_instances['labels'] for gt_instances in batch_gt_instances
- ]
- gt_masks_list = [
- gt_instances['masks'] for gt_instances in batch_gt_instances
- ]
- gt_semantic_segs = [
- None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
- for gt_semantic_seg in batch_gt_semantic_segs
- ]
- targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
- gt_masks_list, gt_semantic_segs, num_things_list,
- num_stuff_list)
- labels, masks = targets
- batch_gt_instances = [
- InstanceData(labels=label, masks=mask)
- for label, mask in zip(labels, masks)
- ]
- return batch_gt_instances
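- # Sketch of the conversion above (semantics of preprocess_panoptic_gt;
- # shapes are illustrative assumptions). For panoptic training, each stuff
- # class present in the semantic map contributes one extra mask/label pair
- # on top of the per-instance thing annotations:
- # before: labels (num_thing_gts, ), masks (num_thing_gts, h, w)
- # after:  labels (num_thing_gts + num_stuff_present, ),
- #         masks  (num_thing_gts + num_stuff_present, h, w)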
- def get_targets(
- self,
- cls_scores_list: List[Tensor],
- mask_preds_list: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict],
- return_sampling_results: bool = False
- ) -> Tuple[List[Union[Tensor, int]]]:
- """Compute classification and mask targets for all images for a decoder
- layer.
- Args:
- cls_scores_list (list[Tensor]): Mask score logits from a single
- decoder layer for all images. Each with shape (num_queries,
- cls_out_channels).
- mask_preds_list (list[Tensor]): Mask logits from a single decoder
- layer for all images. Each with shape (num_queries, h, w).
- batch_gt_instances (list[:obj:`InstanceData`]): each contains
- ``labels`` and ``masks``.
- batch_img_metas (list[dict]): List of image meta information.
- return_sampling_results (bool): Whether to return the sampling
- results. Defaults to False.
- Returns:
- tuple: a tuple containing the following targets.
- - labels_list (list[Tensor]): Labels of all images.\
- Each with shape (num_queries, ).
- - label_weights_list (list[Tensor]): Label weights\
- of all images. Each with shape (num_queries, ).
- - mask_targets_list (list[Tensor]): Mask targets of\
- all images. Each with shape (num_pos, h, w), where\
- num_pos is the number of matched (positive) queries in the image.
- - mask_weights_list (list[Tensor]): Mask weights of\
- all images. Each with shape (num_queries, ).
- - avg_factor (int): Average factor that is used to average\
- the loss. When using sampling method, avg_factor is
- usually the sum of positive and negative priors. When
- using `MaskPseudoSampler`, `avg_factor` is usually equal
- to the number of positive priors.
- additional_returns: This function enables user-defined returns from
- `self._get_targets_single`. These returns are currently refined
- to properties at each feature map (i.e. having HxW dimension).
- The results will be concatenated at the end of this function.
- """
- results = multi_apply(self._get_targets_single, cls_scores_list,
- mask_preds_list, batch_gt_instances,
- batch_img_metas)
- (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
- pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
- rest_results = list(results[7:])
- avg_factor = sum(
- sampling_result.avg_factor for sampling_result in sampling_results_list)
- res = (labels_list, label_weights_list, mask_targets_list,
- mask_weights_list, avg_factor)
- if return_sampling_results:
- res = res + (sampling_results_list, )
- return res + tuple(rest_results)
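- # Usage sketch (shapes assumed, batch_size=2, num_queries=100):
- # >>> (labels_list, label_weights_list, mask_targets_list,
- # ...  mask_weights_list, avg_factor) = self.get_targets(
- # ...     cls_scores_list, mask_preds_list, batch_gt_instances,
- # ...     batch_img_metas)
- # >>> labels_list[0].shape   # (100, )
- # >>> # avg_factor: total number of matched (positive) queries in the batch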
- def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
- gt_instances: InstanceData,
- img_meta: dict) -> Tuple[Tensor]:
- """Compute classification and mask targets for one image.
- Args:
- cls_score (Tensor): Mask score logits from a single decoder layer
- for one image. Shape (num_queries, cls_out_channels).
- mask_pred (Tensor): Mask logits for a single decoder layer for one
- image. Shape (num_queries, h, w).
- gt_instances (:obj:`InstanceData`): It contains ``labels`` and
- ``masks``.
- img_meta (dict): Image meta information.
- Returns:
- tuple: a tuple containing the following for one image.
- - labels (Tensor): Labels of the image.
- shape (num_queries, ).
- - label_weights (Tensor): Label weights of the image.
- shape (num_queries, ).
- - mask_targets (Tensor): Mask targets of the image,
- shape (num_pos, h, w), where num_pos is the number of
- matched (positive) queries.
- - mask_weights (Tensor): Mask weights of the image.
- shape (num_queries, ).
- - pos_inds (Tensor): Sampled positive indices for each image.
- - neg_inds (Tensor): Sampled negative indices for each image.
- - sampling_result (:obj:`SamplingResult`): Sampling results.
- """
- gt_masks = gt_instances.masks
- gt_labels = gt_instances.labels
- target_shape = mask_pred.shape[-2:]
- if gt_masks.shape[0] > 0:
- gt_masks_downsampled = F.interpolate(
- gt_masks.unsqueeze(1).float(), target_shape,
- mode='nearest').squeeze(1).long()
- else:
- gt_masks_downsampled = gt_masks
- pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
- downsampled_gt_instances = InstanceData(
- labels=gt_labels, masks=gt_masks_downsampled)
- # assign and sample
- assign_result = self.assigner.assign(
- pred_instances=pred_instances,
- gt_instances=downsampled_gt_instances,
- img_meta=img_meta)
- sampling_result = self.sampler.sample(
- assign_result=assign_result,
- pred_instances=pred_instances,
- gt_instances=gt_instances)
- pos_inds = sampling_result.pos_inds
- neg_inds = sampling_result.neg_inds
- # label target
- labels = gt_labels.new_full((self.num_queries, ),
- self.num_classes,
- dtype=torch.long)
- labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
- label_weights = gt_labels.new_ones(self.num_queries)
- # mask target
- mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
- mask_weights = mask_pred.new_zeros((self.num_queries, ))
- mask_weights[pos_inds] = 1.0
- return (labels, label_weights, mask_targets, mask_weights, pos_inds,
- neg_inds, sampling_result)
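- # Illustrative outcome (values assumed): with num_queries=100 and 3 GT
- # instances, a Hungarian-style assigner matches exactly 3 queries; those
- # pos_inds receive the GT label and mask_weight 1.0, while the remaining
- # 97 queries get the background label self.num_classes and mask_weight 0.0.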
- def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor,
- batch_gt_instances: List[InstanceData],
- batch_img_metas: List[dict]) -> Dict[str, Tensor]:
- """Loss function.
- Args:
- all_cls_scores (Tensor): Classification scores for all decoder
- layers with shape (num_decoder, batch_size, num_queries,
- cls_out_channels). Note that `cls_out_channels` should include
- the background class.
- all_mask_preds (Tensor): Mask scores for all decoder layers with
- shape (num_decoder, batch_size, num_queries, h, w).
- batch_gt_instances (list[:obj:`InstanceData`]): each contains
- ``labels`` and ``masks``.
- batch_img_metas (list[dict]): List of image meta information.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- num_dec_layers = len(all_cls_scores)
- batch_gt_instances_list = [
- batch_gt_instances for _ in range(num_dec_layers)
- ]
- img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
- losses_cls, losses_mask, losses_dice = multi_apply(
- self._loss_by_feat_single, all_cls_scores, all_mask_preds,
- batch_gt_instances_list, img_metas_list)
- loss_dict = dict()
- # loss from the last decoder layer
- loss_dict['loss_cls'] = losses_cls[-1]
- loss_dict['loss_mask'] = losses_mask[-1]
- loss_dict['loss_dice'] = losses_dice[-1]
- # loss from other decoder layers
- num_dec_layer = 0
- for loss_cls_i, loss_mask_i, loss_dice_i in zip(
- losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
- loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
- loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
- loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
- num_dec_layer += 1
- return loss_dict
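- # Resulting keys for a 6-layer decoder: 'loss_cls', 'loss_mask' and
- # 'loss_dice' hold the last layer's losses, while 'd0.loss_cls' ...
- # 'd4.loss_dice' hold the auxiliary losses of the five preceding layers.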
- def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
- batch_gt_instances: List[InstanceData],
- batch_img_metas: List[dict]) -> Tuple[Tensor]:
- """Loss function for outputs from a single decoder layer.
- Args:
- cls_scores (Tensor): Mask score logits from a single decoder layer
- for all images. Shape (batch_size, num_queries,
- cls_out_channels). Note that `cls_out_channels` should include
- the background class.
- mask_preds (Tensor): Mask logits from a single decoder layer for
- all images. Shape (batch_size, num_queries, h, w).
- batch_gt_instances (list[:obj:`InstanceData`]): each contains
- ``labels`` and ``masks``.
- batch_img_metas (list[dict]): List of image meta information.
- Returns:
- tuple[Tensor]: Loss components for outputs from a single decoder\
- layer.
- """
- num_imgs = cls_scores.size(0)
- cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
- mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
- (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
- avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
- batch_gt_instances, batch_img_metas)
- # shape (batch_size, num_queries)
- labels = torch.stack(labels_list, dim=0)
- # shape (batch_size, num_queries)
- label_weights = torch.stack(label_weights_list, dim=0)
- # shape (num_total_gts, h, w)
- mask_targets = torch.cat(mask_targets_list, dim=0)
- # shape (batch_size, num_queries)
- mask_weights = torch.stack(mask_weights_list, dim=0)
- # classification loss
- # shape (batch_size * num_queries, )
- cls_scores = cls_scores.flatten(0, 1)
- labels = labels.flatten(0, 1)
- label_weights = label_weights.flatten(0, 1)
- class_weight = cls_scores.new_tensor(self.class_weight)
- loss_cls = self.loss_cls(
- cls_scores,
- labels,
- label_weights,
- avg_factor=class_weight[labels].sum())
- num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
- num_total_masks = max(num_total_masks, 1)
- # extract positive ones
- # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
- mask_preds = mask_preds[mask_weights > 0]
- target_shape = mask_targets.shape[-2:]
- if mask_targets.shape[0] == 0:
- # zero match
- loss_dice = mask_preds.sum()
- loss_mask = mask_preds.sum()
- return loss_cls, loss_mask, loss_dice
- # upsample to shape of target
- # shape (num_total_gts, h, w)
- mask_preds = F.interpolate(
- mask_preds.unsqueeze(1),
- target_shape,
- mode='bilinear',
- align_corners=False).squeeze(1)
- # dice loss
- loss_dice = self.loss_dice(
- mask_preds, mask_targets, avg_factor=num_total_masks)
- # mask loss
- # FocalLoss supports inputs of shape (n, num_classes)
- h, w = mask_preds.shape[-2:]
- # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
- mask_preds = mask_preds.reshape(-1, 1)
- # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
- mask_targets = mask_targets.reshape(-1)
- # invert the binary targets: mmdet's sigmoid FocalLoss treats target 0 as
- # "channel 0 is positive", so foreground pixels (mask_targets == 1) must
- # become target 0 and background pixels target 1
- loss_mask = self.loss_mask(
- mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)
- return loss_cls, loss_mask, loss_dice
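- # Shape flow sketch for the mask losses above (sizes assumed): with
- # batch_size=2, num_queries=100 and 5 matched queries in total,
- # mask_preds after indexing:  (5, h, w) -> interpolate -> (5, H, W)
- # dice loss: averaged over num_total_masks (~5, mean-reduced across GPUs)
- # focal loss: computed on (5 * H * W, 1) logits vs inverted binary targets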
- def forward(self, x: Tuple[Tensor],
- batch_data_samples: SampleList) -> Tuple[Tensor]:
- """Forward function.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each
- is a 4D-tensor.
- batch_data_samples (List[:obj:`DetDataSample`]): The Data
- Samples. It usually includes information such as
- `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
- Returns:
- tuple[Tensor]: a tuple containing two elements.
- - all_cls_scores (Tensor): Classification scores for each\
- decoder layer. Each is a 4D-tensor with shape\
- (num_decoder, batch_size, num_queries, cls_out_channels).\
- Note that `cls_out_channels` should include the background class.
- - all_mask_preds (Tensor): Mask scores for each decoder\
- layer. Each with shape (num_decoder, batch_size,\
- num_queries, h, w).
- """
- batch_img_metas = [
- data_sample.metainfo for data_sample in batch_data_samples
- ]
- batch_size = len(batch_img_metas)
- input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
- padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w),
- dtype=torch.float32)
- for i in range(batch_size):
- img_h, img_w = batch_img_metas[i]['img_shape']
- padding_mask[i, :img_h, :img_w] = 0
- padding_mask = F.interpolate(
- padding_mask.unsqueeze(1), size=x[-1].shape[-2:],
- mode='nearest').to(torch.bool).squeeze(1)
- # when the backbone is Swin, memory is the output of the last stage of
- # Swin; when the backbone is ResNet-50, memory is the output of the
- # transformer encoder in the pixel decoder
- mask_features, memory = self.pixel_decoder(x, batch_img_metas)
- pos_embed = self.decoder_pe(padding_mask)
- memory = self.decoder_input_proj(memory)
- # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
- memory = memory.flatten(2).permute(0, 2, 1)
- pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
- # shape (batch_size, h * w)
- padding_mask = padding_mask.flatten(1)
- # shape (num_queries, embed_dims)
- query_embed = self.query_embed.weight
- # shape (batch_size, num_queries, embed_dims)
- query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1)
- target = torch.zeros_like(query_embed)
- # shape (num_decoder, batch_size, num_queries, embed_dims)
- out_dec = self.transformer_decoder(
- query=target,
- key=memory,
- value=memory,
- query_pos=query_embed,
- key_pos=pos_embed,
- key_padding_mask=padding_mask)
- # cls_scores
- all_cls_scores = self.cls_embed(out_dec)
- # mask_preds
- mask_embed = self.mask_embed(out_dec)
- all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
- mask_features)
- return all_cls_scores, all_mask_preds
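- # Shape walk-through of the einsum above (illustrative sizes:
- # num_decoder=6, batch_size=2, num_queries=100, c=256, h=w=200):
- #   mask_embed     'lbqc'  -> (6, 2, 100, 256)
- #   mask_features  'bchw'  -> (2, 256, 200, 200)
- #   all_mask_preds 'lbqhw' -> (6, 2, 100, 200, 200)
- # i.e. each query's mask logits are the dot product of its mask embedding
- # with every per-pixel feature vector.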
- def loss(
- self,
- x: Tuple[Tensor],
- batch_data_samples: SampleList,
- ) -> Dict[str, Tensor]:
- """Perform forward propagation and loss calculation of the panoptic
- head on the features of the upstream network.
- Args:
- x (tuple[Tensor]): Multi-level features from the upstream
- network, each is a 4D-tensor.
- batch_data_samples (List[:obj:`DetDataSample`]): The Data
- Samples. It usually includes information such as
- `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
- Returns:
- dict[str, Tensor]: a dictionary of loss components
- """
- batch_img_metas = []
- batch_gt_instances = []
- batch_gt_semantic_segs = []
- for data_sample in batch_data_samples:
- batch_img_metas.append(data_sample.metainfo)
- batch_gt_instances.append(data_sample.gt_instances)
- if 'gt_sem_seg' in data_sample:
- batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
- else:
- batch_gt_semantic_segs.append(None)
- # forward
- all_cls_scores, all_mask_preds = self(x, batch_data_samples)
- # preprocess ground truth
- batch_gt_instances = self.preprocess_gt(batch_gt_instances,
- batch_gt_semantic_segs)
- # loss
- losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
- batch_gt_instances, batch_img_metas)
- return losses
- def predict(self, x: Tuple[Tensor],
- batch_data_samples: SampleList) -> Tuple[Tensor]:
- """Test without augmentaton.
- Args:
- x (tuple[Tensor]): Multi-level features from the
- upstream network, each is a 4D-tensor.
- batch_data_samples (List[:obj:`DetDataSample`]): The Data
- Samples. It usually includes information such as
- `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
- Returns:
- tuple[Tensor]: A tuple containing two tensors.
- - mask_cls_results (Tensor): Mask classification logits,\
- shape (batch_size, num_queries, cls_out_channels).
- Note that `cls_out_channels` should include the background class.
- - mask_pred_results (Tensor): Mask logits, shape \
- (batch_size, num_queries, h, w).
- """
- batch_img_metas = [
- data_sample.metainfo for data_sample in batch_data_samples
- ]
- all_cls_scores, all_mask_preds = self(x, batch_data_samples)
- mask_cls_results = all_cls_scores[-1]
- mask_pred_results = all_mask_preds[-1]
- # upsample masks
- img_shape = batch_img_metas[0]['batch_input_shape']
- mask_pred_results = F.interpolate(
- mask_pred_results,
- size=(img_shape[0], img_shape[1]),
- mode='bilinear',
- align_corners=False)
- return mask_cls_results, mask_pred_results
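- # Usage sketch (attribute name assumed): inside a MaskFormer-style
- # detector's predict path,
- # >>> mask_cls_results, mask_pred_results = self.panoptic_head.predict(
- # ...     feats, batch_data_samples)
- # >>> # mask_cls_results:  (batch_size, num_queries, num_classes + 1)
- # >>> # mask_pred_results: (batch_size, num_queries, H_pad, W_pad), already
- # >>> #   upsampled to the padded input size for panoptic post-processing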