# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmengine.model import caffe2_xavier_init
from mmengine.structures import InstanceData, PixelData
from torch import Tensor

from mmdet.models.layers.pixel_decoder import PixelDecoder
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptMultiConfig, reduce_mean)
from ..layers import DetrTransformerDecoder, SinePositionalEncoding
from ..utils import multi_apply, preprocess_panoptic_gt
from .anchor_free_head import AnchorFreeHead


@MODELS.register_module()
class MaskFormerHead(AnchorFreeHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for outputs.
        num_things_classes (int): Number of things classes.
        num_stuff_classes (int): Number of stuff classes.
        num_queries (int): Number of queries in the Transformer decoder.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for the pixel
            decoder.
        enforce_decoder_input_project (bool): Whether to add a layer
            to change the embed_dim of the transformer encoder in the pixel
            decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder positional encoding.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to `FocalLoss`.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the MaskFormer head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the MaskFormer head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 feat_channels: int,
                 out_channels: int,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 num_queries: int = 100,
                 pixel_decoder: ConfigType = ...,
                 enforce_decoder_input_project: bool = False,
                 transformer_decoder: ConfigType = ...,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask: ConfigType = dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=20.0),
                 loss_dice: ConfigType = dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     naive_dice=True,
                     loss_weight=1.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        pixel_decoder.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder)
        self.transformer_decoder = DetrTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        if type(self.pixel_decoder) == PixelDecoder and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(
                in_channels[-1], self.decoder_embed_dims, kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = SinePositionalEncoding(**positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                train_cfg['sampler'], default_args=dict(context=self))

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def init_weights(self) -> None:
        if isinstance(self.decoder_input_proj, Conv2d):
            caffe2_xavier_init(self.decoder_input_proj, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def preprocess_gt(
            self, batch_gt_instances: InstanceList,
            batch_gt_semantic_segs: List[Optional[PixelData]]
    ) -> InstanceList:
        """Preprocess the ground truth for all images.

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, the ground
                truth labels of each bbox, with shape (num_gts, ), and
                ``masks``, the ground truth masks of the instances in an
                image, with shape (num_gts, h, w).
            batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth
                of semantic segmentation, each with the shape (1, h, w).
                [0, num_thing_class - 1] means things,
                [num_thing_class, num_class - 1] means stuff,
                255 means VOID. It is None when training instance
                segmentation.

        Returns:
            list[obj:`InstanceData`]: each contains the following keys

                - labels (Tensor): Ground truth class indices\
                    for an image, with shape (n, ), n is the sum of the\
                    number of stuff types and the number of instances\
                    in an image.
                - masks (Tensor): Ground truth mask for an\
                    image, with shape (n, h, w).
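
        Example:
            A minimal sketch of the per-image ``InstanceData`` built by this
            method, using dummy tensors (shapes are illustrative only):

            >>> import torch
            >>> from mmengine.structures import InstanceData
            >>> gt = InstanceData(labels=torch.tensor([0, 1]),
            ...                   masks=torch.zeros(2, 4, 4))
            >>> (gt.labels.shape, gt.masks.shape)
            (torch.Size([2]), torch.Size([2, 4, 4]))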
  151. """
        num_things_list = [self.num_things_classes] * len(batch_gt_instances)
        num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances)
        gt_labels_list = [
            gt_instances['labels'] for gt_instances in batch_gt_instances
        ]
        gt_masks_list = [
            gt_instances['masks'] for gt_instances in batch_gt_instances
        ]
        gt_semantic_segs = [
            None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
            for gt_semantic_seg in batch_gt_semantic_segs
        ]
        targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                              gt_masks_list, gt_semantic_segs,
                              num_things_list, num_stuff_list)
        labels, masks = targets
        batch_gt_instances = [
            InstanceData(labels=label, masks=mask)
            for label, mask in zip(labels, masks)
        ]
        return batch_gt_instances

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights\
                    of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - avg_factor (int): Average factor that is used to average\
                    the loss. When using sampling method, avg_factor is
                    usually the sum of positive and negative priors. When
                    using `MaskPseudoSampler`, `avg_factor` is usually equal
                    to the number of positive priors.

            additional_returns: This function enables user-defined returns
                from ``self._get_targets_single``. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated after the end.
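
        Example:
            A sketch of the ``multi_apply`` helper this method relies on,
            shown with a toy function (illustrative only, not part of this
            file):

            >>> from mmdet.models.utils import multi_apply
            >>> multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
            ([4, 6], [3, 8])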
  213. """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, pos_inds_list, neg_inds_list,
         sampling_results_list) = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            # must be a one-element tuple; tuple + list raises a TypeError
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
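
        Example:
            A minimal sketch of the label-target filling performed below,
            with dummy tensors (indices and sizes are illustrative only):

            >>> import torch
            >>> num_queries, num_classes = 5, 3
            >>> gt_labels = torch.tensor([0, 2])
            >>> pos_inds = torch.tensor([1, 4])
            >>> labels = gt_labels.new_full((num_queries, ), num_classes)
            >>> labels[pos_inds] = gt_labels  # positive queries get gt labels
            >>> labels.tolist()
            [3, 0, 3, 3, 2]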
  252. """
        gt_masks = gt_instances.masks
        gt_labels = gt_instances.labels

        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        downsampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_masks_downsampled)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=downsampled_gt_instances,
            img_meta=img_meta)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(self.num_queries)

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)

    def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor,
                     batch_gt_instances: List[InstanceData],
                     batch_img_metas: List[dict]) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
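
        Example:
            A sketch of how losses from the non-final decoder layers are
            keyed in the returned dict (illustrative only):

            >>> [f'd{i}.loss_cls' for i in range(2)]
            ['d0.loss_cls', 'd1.loss_cls']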
  304. """
        num_dec_layers = len(all_cls_scores)
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]
        img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self._loss_by_feat_single, all_cls_scores, all_mask_preds,
            batch_gt_instances_list, img_metas_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_mask'] = losses_mask[-1]
        loss_dict['loss_dice'] = losses_dice[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_mask_i, loss_dice_i in zip(
                losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
            loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
            num_dec_layer += 1
        return loss_dict

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single decoder\
                layer.
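
        Example:
            A sketch of the flattening applied before the mask (focal) loss
            below, with illustrative sizes:

            >>> import torch
            >>> mask_preds = torch.rand(3, 8, 8)  # (num_total_gts, h, w)
            >>> mask_preds.reshape(-1, 1).shape   # (num_total_gts * h * w, 1)
            torch.Size([192, 1])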
  344. """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]

        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, avg_factor) = self.get_targets(
             cls_scores_list, mask_preds_list, batch_gt_instances,
             batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]
        target_shape = mask_targets.shape[-2:]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        # upsample to shape of target
        # shape (num_total_gts, h, w)
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(1),
            target_shape,
            mode='bilinear',
            align_corners=False).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_preds, mask_targets, avg_factor=num_total_masks)

        # mask loss
        # FocalLoss supports input of shape (n, num_class)
        h, w = mask_preds.shape[-2:]
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
        mask_preds = mask_preds.reshape(-1, 1)
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
        mask_targets = mask_targets.reshape(-1)
        # target is (1 - mask_targets) !!!
        loss_mask = self.loss_mask(
            mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)

        return loss_cls, loss_mask, loss_dice

    def forward(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Forward function.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: a tuple contains two elements.

                - all_cls_scores (Tensor): Classification scores for each\
                    scale level. Each is a 4D-tensor with shape\
                    (num_decoder, batch_size, num_queries, cls_out_channels).\
                    Note `cls_out_channels` should include background.
                - all_mask_preds (Tensor): Mask scores for each decoder\
                    layer. Each with shape (num_decoder, batch_size,\
                    num_queries, h, w).
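
        Example:
            A shape sketch of the final ``einsum`` below, with illustrative
            dimensions (layers, batch, queries, channels, height, width):

            >>> import torch
            >>> mask_embed = torch.rand(6, 2, 4, 8)       # (l, b, q, c)
            >>> mask_features = torch.rand(2, 8, 16, 16)  # (b, c, h, w)
            >>> torch.einsum('lbqc,bchw->lbqhw', mask_embed,
            ...              mask_features).shape
            torch.Size([6, 2, 4, 16, 16])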
  420. """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
        padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w),
                                      dtype=torch.float32)
        for i in range(batch_size):
            img_h, img_w = batch_img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1), size=x[-1].shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)
        # when backbone is swin, memory is output of last stage of swin.
        # when backbone is r50, memory is output of transformer encoder.
        mask_features, memory = self.pixel_decoder(x, batch_img_metas)
        pos_embed = self.decoder_pe(padding_mask)
        memory = self.decoder_input_proj(memory)
        # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
        memory = memory.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # shape (batch_size, h * w)
        padding_mask = padding_mask.flatten(1)
        # shape = (num_queries, embed_dims)
        query_embed = self.query_embed.weight
        # shape = (batch_size, num_queries, embed_dims)
        query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1)
        target = torch.zeros_like(query_embed)
        # shape (num_decoder, batch_size, num_queries, embed_dims)
        out_dec = self.transformer_decoder(
            query=target,
            key=memory,
            value=memory,
            query_pos=query_embed,
            key_pos=pos_embed,
            key_padding_mask=padding_mask)

        # cls_scores
        all_cls_scores = self.cls_embed(out_dec)

        # mask_preds
        mask_embed = self.mask_embed(out_dec)
        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
                                      mask_features)

        return all_cls_scores, all_mask_preds

    def loss(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)
            if 'gt_sem_seg' in data_sample:
                batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
            else:
                batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                batch_gt_semantic_segs)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two tensors.

                - mask_cls_results (Tensor): Mask classification logits,\
                    shape (batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - mask_pred_results (Tensor): Mask logits, shape \
                    (batch_size, num_queries, h, w).
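
        Example:
            A sketch of the final upsampling to ``batch_input_shape``
            performed below, with illustrative sizes:

            >>> import torch
            >>> import torch.nn.functional as F
            >>> preds = torch.rand(1, 4, 32, 32)  # (b, num_queries, h, w)
            >>> F.interpolate(preds, size=(128, 128), mode='bilinear',
            ...               align_corners=False).shape
            torch.Size([1, 4, 128, 128])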
  515. """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results