- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import List, Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import Conv2d
- from mmcv.ops import point_sample
- from mmengine.model import ModuleList, caffe2_xavier_init
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.structures import SampleList
- from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean
- from ..layers import Mask2FormerTransformerDecoder, SinePositionalEncoding
- from ..utils import get_uncertain_point_coords_with_randomness
- from .anchor_free_head import AnchorFreeHead
- from .maskformer_head import MaskFormerHead
- @MODELS.register_module()
- class Mask2FormerHead(MaskFormerHead):
- """Implements the Mask2Former head.
- See `Masked-attention Mask Transformer for Universal Image
- Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
- Args:
- in_channels (list[int]): Number of channels in the input feature map.
- feat_channels (int): Number of channels for features.
- out_channels (int): Number of channels for output.
- num_things_classes (int): Number of things classes. Defaults to 80.
- num_stuff_classes (int): Number of stuff classes. Defaults to 53.
- num_queries (int): Number of queries in the Transformer decoder.
- Defaults to 100.
- num_transformer_feat_level (int): Number of multi-scale feature
- levels used by the Transformer decoder. Defaults to 3.
- pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
- decoder. Defaults to None.
- enforce_decoder_input_project (bool, optional): Whether to add
- a layer to change the embed_dim of the transformer encoder in the
- pixel decoder to the embed_dim of the transformer decoder.
- Defaults to False.
- transformer_decoder (:obj:`ConfigDict` or dict): Config for
- transformer decoder. Defaults to None.
- positional_encoding (:obj:`ConfigDict` or dict): Config for
- transformer decoder position encoding. Defaults to
- dict(num_feats=128, normalize=True).
- loss_cls (:obj:`ConfigDict` or dict): Config of the classification
- loss. Defaults to a ``CrossEntropyLoss`` config.
- loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
- Defaults to a ``CrossEntropyLoss`` config with ``use_sigmoid=True``.
- loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
- Defaults to a ``DiceLoss`` config.
- train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
- Mask2Former head.
- test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
- Mask2Former head.
- init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
- dict], optional): Initialization config dict. Defaults to None.
- """
- def __init__(self,
- in_channels: List[int],
- feat_channels: int,
- out_channels: int,
- num_things_classes: int = 80,
- num_stuff_classes: int = 53,
- num_queries: int = 100,
- num_transformer_feat_level: int = 3,
- pixel_decoder: ConfigType = ...,
- enforce_decoder_input_project: bool = False,
- transformer_decoder: ConfigType = ...,
- positional_encoding: ConfigType = dict(
- num_feats=128, normalize=True),
- loss_cls: ConfigType = dict(
- type='CrossEntropyLoss',
- use_sigmoid=False,
- loss_weight=2.0,
- reduction='mean',
- class_weight=[1.0] * 133 + [0.1]),
- loss_mask: ConfigType = dict(
- type='CrossEntropyLoss',
- use_sigmoid=True,
- reduction='mean',
- loss_weight=5.0),
- loss_dice: ConfigType = dict(
- type='DiceLoss',
- use_sigmoid=True,
- activate=True,
- reduction='mean',
- naive_dice=True,
- eps=1.0,
- loss_weight=5.0),
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- init_cfg: OptMultiConfig = None,
- **kwargs) -> None:
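- # NOTE: deliberately skip MaskFormerHead.__init__ (and AnchorFreeHead's):
- # this head builds all of its own modules below, so only the base
- # dense-head initialization with `init_cfg` is needed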
- super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
- self.num_things_classes = num_things_classes
- self.num_stuff_classes = num_stuff_classes
- self.num_classes = self.num_things_classes + self.num_stuff_classes
- self.num_queries = num_queries
- self.num_transformer_feat_level = num_transformer_feat_level
- self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
- self.num_transformer_decoder_layers = transformer_decoder.num_layers
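- # the deformable attention inside the pixel decoder must be configured
- # with exactly one level per transformer feature level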
- assert pixel_decoder.encoder.layer_cfg. \
- self_attn_cfg.num_levels == num_transformer_feat_level
- pixel_decoder_ = copy.deepcopy(pixel_decoder)
- pixel_decoder_.update(
- in_channels=in_channels,
- feat_channels=feat_channels,
- out_channels=out_channels)
- self.pixel_decoder = MODELS.build(pixel_decoder_)
- self.transformer_decoder = Mask2FormerTransformerDecoder(
- **transformer_decoder)
- self.decoder_embed_dims = self.transformer_decoder.embed_dims
- self.decoder_input_projs = ModuleList()
- # from low resolution to high resolution
- for _ in range(num_transformer_feat_level):
- if (self.decoder_embed_dims != feat_channels
- or enforce_decoder_input_project):
- self.decoder_input_projs.append(
- Conv2d(
- feat_channels, self.decoder_embed_dims, kernel_size=1))
- else:
- self.decoder_input_projs.append(nn.Identity())
- self.decoder_positional_encoding = SinePositionalEncoding(
- **positional_encoding)
- self.query_embed = nn.Embedding(self.num_queries, feat_channels)
- self.query_feat = nn.Embedding(self.num_queries, feat_channels)
- # from low resolution to high resolution
- self.level_embed = nn.Embedding(self.num_transformer_feat_level,
- feat_channels)
- self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
- self.mask_embed = nn.Sequential(
- nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
- nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
- nn.Linear(feat_channels, out_channels))
- self.test_cfg = test_cfg
- self.train_cfg = train_cfg
- if train_cfg:
- self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
- self.sampler = TASK_UTILS.build(
- self.train_cfg['sampler'], default_args=dict(context=self))
- self.num_points = self.train_cfg.get('num_points', 12544)
- self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
- self.importance_sample_ratio = self.train_cfg.get(
- 'importance_sample_ratio', 0.75)
- self.class_weight = loss_cls.class_weight
- self.loss_cls = MODELS.build(loss_cls)
- self.loss_mask = MODELS.build(loss_mask)
- self.loss_dice = MODELS.build(loss_dice)
- def init_weights(self) -> None:
- for m in self.decoder_input_projs:
- if isinstance(m, Conv2d):
- caffe2_xavier_init(m, bias=0)
- self.pixel_decoder.init_weights()
- for p in self.transformer_decoder.parameters():
- if p.dim() > 1:
- nn.init.xavier_normal_(p)
- def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
- gt_instances: InstanceData,
- img_meta: dict) -> Tuple[Tensor]:
- """Compute classification and mask targets for one image.
- Args:
- cls_score (Tensor): Classification score logits from a single
- decoder layer for one image. Shape (num_queries, cls_out_channels).
- mask_pred (Tensor): Mask logits for a single decoder layer for one
- image. Shape (num_queries, h, w).
- gt_instances (:obj:`InstanceData`): It contains ``labels`` and
- ``masks``.
- img_meta (dict): Image information.
- Returns:
- tuple[Tensor]: A tuple containing the following for one image.
- - labels (Tensor): Labels of each image. \
- shape (num_queries, ).
- - label_weights (Tensor): Label weights of each image. \
- shape (num_queries, ).
- - mask_targets (Tensor): Mask targets of each image. \
- shape (num_queries, h, w).
- - mask_weights (Tensor): Mask weights of each image. \
- shape (num_queries, ).
- - pos_inds (Tensor): Sampled positive indices for each \
- image.
- - neg_inds (Tensor): Sampled negative indices for each \
- image.
- - sampling_result (:obj:`SamplingResult`): Sampling results.
- """
- gt_labels = gt_instances.labels
- gt_masks = gt_instances.masks
- # sample points
- num_queries = cls_score.shape[0]
- num_gts = gt_labels.shape[0]
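- # following Mask2Former, the matching cost is computed on a shared set
- # of uniformly sampled points rather than on full-resolution masks,
- # which greatly reduces memory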
- point_coords = torch.rand((1, self.num_points, 2),
- device=cls_score.device)
- # shape (num_queries, num_points)
- mask_points_pred = point_sample(
- mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
- 1)).squeeze(1)
- # shape (num_gts, num_points)
- gt_points_masks = point_sample(
- gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
- 1)).squeeze(1)
- sampled_gt_instances = InstanceData(
- labels=gt_labels, masks=gt_points_masks)
- sampled_pred_instances = InstanceData(
- scores=cls_score, masks=mask_points_pred)
- # assign and sample
- assign_result = self.assigner.assign(
- pred_instances=sampled_pred_instances,
- gt_instances=sampled_gt_instances,
- img_meta=img_meta)
- pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
- sampling_result = self.sampler.sample(
- assign_result=assign_result,
- pred_instances=pred_instances,
- gt_instances=gt_instances)
- pos_inds = sampling_result.pos_inds
- neg_inds = sampling_result.neg_inds
- # label target
- labels = gt_labels.new_full((self.num_queries, ),
- self.num_classes,
- dtype=torch.long)
- labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
- label_weights = gt_labels.new_ones((self.num_queries, ))
- # mask target
- mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
- mask_weights = mask_pred.new_zeros((self.num_queries, ))
- mask_weights[pos_inds] = 1.0
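- # only matched (positive) queries receive mask supervision; unmatched
- # queries are trained as background through the classification loss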
- return (labels, label_weights, mask_targets, mask_weights, pos_inds,
- neg_inds, sampling_result)
- def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
- batch_gt_instances: List[InstanceData],
- batch_img_metas: List[dict]) -> Tuple[Tensor]:
- """Loss function for outputs from a single decoder layer.
- Args:
- cls_scores (Tensor): Classification score logits from a single
- decoder layer for all images. Shape (batch_size, num_queries,
- cls_out_channels). Note `cls_out_channels` should include
- background.
- mask_preds (Tensor): Mask logits from a single decoder layer for
- all images. Shape (batch_size, num_queries, h, w).
- batch_gt_instances (list[:obj:`InstanceData`]): Each contains
- ``labels`` and ``masks``.
- batch_img_metas (list[dict]): List of image meta information.
- Returns:
- tuple[Tensor]: Loss components for outputs from a single \
- decoder layer.
- """
- num_imgs = cls_scores.size(0)
- cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
- mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
- (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
- avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
- batch_gt_instances, batch_img_metas)
- # shape (batch_size, num_queries)
- labels = torch.stack(labels_list, dim=0)
- # shape (batch_size, num_queries)
- label_weights = torch.stack(label_weights_list, dim=0)
- # shape (num_total_gts, h, w)
- mask_targets = torch.cat(mask_targets_list, dim=0)
- # shape (batch_size, num_queries)
- mask_weights = torch.stack(mask_weights_list, dim=0)
- # classification loss
- # shape (batch_size * num_queries, )
- cls_scores = cls_scores.flatten(0, 1)
- labels = labels.flatten(0, 1)
- label_weights = label_weights.flatten(0, 1)
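- # normalize by the summed per-target class weights so the down-weighted
- # background class (weight 0.1 by default) does not dominate the loss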
- class_weight = cls_scores.new_tensor(self.class_weight)
- loss_cls = self.loss_cls(
- cls_scores,
- labels,
- label_weights,
- avg_factor=class_weight[labels].sum())
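- # average the number of matched masks across all GPUs so every rank
- # normalizes the mask losses by the same factor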
- num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
- num_total_masks = max(num_total_masks, 1)
- # extract positive ones
- # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
- mask_preds = mask_preds[mask_weights > 0]
- if mask_targets.shape[0] == 0:
- # zero match
- loss_dice = mask_preds.sum()
- loss_mask = mask_preds.sum()
- return loss_cls, loss_mask, loss_dice
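- # PointRend-style sampling: pick mostly high-uncertainty points (where
- # mask logits are close to the decision boundary) plus some random ones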
- with torch.no_grad():
- points_coords = get_uncertain_point_coords_with_randomness(
- mask_preds.unsqueeze(1), None, self.num_points,
- self.oversample_ratio, self.importance_sample_ratio)
- # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
- mask_point_targets = point_sample(
- mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
- # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
- mask_point_preds = point_sample(
- mask_preds.unsqueeze(1), points_coords).squeeze(1)
- # dice loss
- loss_dice = self.loss_dice(
- mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
- # mask loss
- # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
- mask_point_preds = mask_point_preds.reshape(-1)
- # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
- mask_point_targets = mask_point_targets.reshape(-1)
- loss_mask = self.loss_mask(
- mask_point_preds,
- mask_point_targets,
- avg_factor=num_total_masks * self.num_points)
- return loss_cls, loss_mask, loss_dice
- def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
- attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]:
- """Forward for head part which is called after every decoder layer.
- Args:
- decoder_out (Tensor): in shape (batch_size, num_queries, c).
- mask_feature (Tensor): in shape (batch_size, c, h, w).
- attn_mask_target_size (tuple[int, int]): target attention
- mask size.
- Returns:
- tuple: A tuple containing three elements.
- - cls_pred (Tensor): Classification scores in shape \
- (batch_size, num_queries, cls_out_channels). \
- Note `cls_out_channels` should include background.
- - mask_pred (Tensor): Mask scores in shape \
- (batch_size, num_queries, h, w).
- - attn_mask (Tensor): Attention mask in shape \
- (batch_size * num_heads, num_queries, h*w).
- """
- decoder_out = self.transformer_decoder.post_norm(decoder_out)
- # shape (batch_size, num_queries, cls_out_channels)
- cls_pred = self.cls_embed(decoder_out)
- # shape (batch_size, num_queries, c)
- mask_embed = self.mask_embed(decoder_out)
- # shape (batch_size, num_queries, h, w)
- mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
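- # (the einsum above dots each query's mask embedding with the feature
- # vector at every pixel, yielding per-query mask logits)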
- attn_mask = F.interpolate(
- mask_pred,
- attn_mask_target_size,
- mode='bilinear',
- align_corners=False)
- # shape (batch_size, num_queries, h, w) ->
- # (batch_size * num_heads, num_queries, h*w)
- attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
- (1, self.num_heads, 1, 1)).flatten(0, 1)
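- # masked attention: True entries are blocked, so each query only attends
- # to locations it currently predicts as foreground (sigmoid prob >= 0.5)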
- attn_mask = attn_mask.sigmoid() < 0.5
- attn_mask = attn_mask.detach()
- return cls_pred, mask_pred, attn_mask
- def forward(self, x: List[Tensor],
- batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
- """Forward function.
- Args:
- x (list[Tensor]): Multi-scale features from the
- upstream network, each a 4D tensor.
- batch_data_samples (List[:obj:`DetDataSample`]): The Data
- Samples. It usually includes information such as
- `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
- Returns:
- tuple[list[Tensor]]: A tuple containing two elements.
- - cls_pred_list (list[Tensor]): Classification logits \
- for each decoder layer. Each is a 3D tensor with shape \
- (batch_size, num_queries, cls_out_channels). \
- Note `cls_out_channels` should include background.
- - mask_pred_list (list[Tensor]): Mask logits for each \
- decoder layer. Each with shape (batch_size, num_queries, \
- h, w).
- """
- batch_img_metas = [
- data_sample.metainfo for data_sample in batch_data_samples
- ]
- batch_size = len(batch_img_metas)
- mask_features, multi_scale_memorys = self.pixel_decoder(x)
- # multi_scale_memorys (from low resolution to high resolution)
- decoder_inputs = []
- decoder_positional_encodings = []
- for i in range(self.num_transformer_feat_level):
- decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
- # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
- decoder_input = decoder_input.flatten(2).permute(0, 2, 1)
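- # add a learned per-level embedding so the decoder can tell the
- # feature levels apart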
- level_embed = self.level_embed.weight[i].view(1, 1, -1)
- decoder_input = decoder_input + level_embed
- # padding mask of shape (batch_size, h, w); nothing is padded here
- mask = decoder_input.new_zeros(
- (batch_size, ) + multi_scale_memorys[i].shape[-2:],
- dtype=torch.bool)
- decoder_positional_encoding = self.decoder_positional_encoding(
- mask)
- decoder_positional_encoding = decoder_positional_encoding.flatten(
- 2).permute(0, 2, 1)
- decoder_inputs.append(decoder_input)
- decoder_positional_encodings.append(decoder_positional_encoding)
- # shape (num_queries, c) -> (batch_size, num_queries, c)
- query_feat = self.query_feat.weight.unsqueeze(0).repeat(
- (batch_size, 1, 1))
- query_embed = self.query_embed.weight.unsqueeze(0).repeat(
- (batch_size, 1, 1))
- cls_pred_list = []
- mask_pred_list = []
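- # predictions from the learned queries alone, before any decoder layer;
- # this also produces the attention mask for the first decoder layer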
- cls_pred, mask_pred, attn_mask = self._forward_head(
- query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
- cls_pred_list.append(cls_pred)
- mask_pred_list.append(mask_pred)
- for i in range(self.num_transformer_decoder_layers):
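- # feature levels are visited round-robin, from low to high resolution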
- level_idx = i % self.num_transformer_feat_level
- # if a mask is all True (all background), set it all False so that
- # cross-attention still has locations to attend to
- attn_mask[torch.where(
- attn_mask.sum(-1) == attn_mask.shape[-1])] = False
- # cross_attn + self_attn
- layer = self.transformer_decoder.layers[i]
- query_feat = layer(
- query=query_feat,
- key=decoder_inputs[level_idx],
- value=decoder_inputs[level_idx],
- query_pos=query_embed,
- key_pos=decoder_positional_encodings[level_idx],
- cross_attn_mask=attn_mask,
- query_key_padding_mask=None,
- # here we do not apply masking on padded region
- key_padding_mask=None)
- cls_pred, mask_pred, attn_mask = self._forward_head(
- query_feat, mask_features, multi_scale_memorys[
- (i + 1) % self.num_transformer_feat_level].shape[-2:])
- cls_pred_list.append(cls_pred)
- mask_pred_list.append(mask_pred)
- return cls_pred_list, mask_pred_list