mask2former_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmcv.ops import point_sample
from mmengine.model import ModuleList, caffe2_xavier_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean
from ..layers import Mask2FormerTransformerDecoder, SinePositionalEncoding
from ..utils import get_uncertain_point_coords_with_randomness
from .anchor_free_head import AnchorFreeHead
from .maskformer_head import MaskFormerHead


@MODELS.register_module()
class Mask2FormerHead(MaskFormerHead):
    """Implements the Mask2Former head.

    See `Masked-attention Mask Transformer for Universal Image
    Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in the input feature maps.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_things_classes (int): Number of things classes. Defaults to 80.
        num_stuff_classes (int): Number of stuff classes. Defaults to 53.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        num_transformer_feat_level (int): Number of multi-scale feature
            levels fed to the Transformer decoder. Defaults to 3.
        pixel_decoder (:obj:`ConfigDict` or dict): Config for the pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in
            the pixel decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`ConfigDict` or dict): Config for the
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer decoder positional encoding. Defaults to
            dict(num_feats=128, normalize=True).
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to None.
        loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
            Defaults to None.
        loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
            Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the Mask2Former head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the Mask2Former head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 feat_channels: int,
                 out_channels: int,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 num_queries: int = 100,
                 num_transformer_feat_level: int = 3,
                 pixel_decoder: ConfigType = ...,
                 enforce_decoder_input_project: bool = False,
                 transformer_decoder: ConfigType = ...,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 loss_cls: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=2.0,
                     reduction='mean',
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=5.0),
                 loss_dice: ConfigType = dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=5.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries
        self.num_transformer_feat_level = num_transformer_feat_level
        self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
        self.num_transformer_decoder_layers = transformer_decoder.num_layers
        assert pixel_decoder.encoder.layer_cfg. \
            self_attn_cfg.num_levels == num_transformer_feat_level
        pixel_decoder_ = copy.deepcopy(pixel_decoder)
        pixel_decoder_.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = MODELS.build(pixel_decoder_)
        self.transformer_decoder = Mask2FormerTransformerDecoder(
            **transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims

        self.decoder_input_projs = ModuleList()
        # from low resolution to high resolution
        for _ in range(num_transformer_feat_level):
            if (self.decoder_embed_dims != feat_channels
                    or enforce_decoder_input_project):
                self.decoder_input_projs.append(
                    Conv2d(
                        feat_channels, self.decoder_embed_dims, kernel_size=1))
            else:
                self.decoder_input_projs.append(nn.Identity())
        self.decoder_positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, feat_channels)
        self.query_feat = nn.Embedding(self.num_queries, feat_channels)
        # from low resolution to high resolution
        self.level_embed = nn.Embedding(self.num_transformer_feat_level,
                                        feat_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)
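
    # Note: ``cls_embed`` above predicts ``num_classes + 1`` logits per query;
    # the extra channel is the "no object" class assigned to unmatched queries
    # in ``_get_targets_single``. The default ``loss_cls`` ``class_weight``
    # lists one weight per class (133 = 80 things + 53 stuff with the defaults
    # above) plus a final 0.1 entry that down-weights "no object".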

    def init_weights(self) -> None:
        for m in self.decoder_input_projs:
            if isinstance(m, Conv2d):
                caffe2_xavier_init(m, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each \
                    image.
                - neg_inds (Tensor): Sampled negative indices for each \
                    image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        assign_result = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances,
            img_meta=img_meta)
        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones((self.num_queries, ))

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)
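
    # Note: assignment in ``_get_targets_single`` is computed on
    # ``self.num_points`` randomly sampled points rather than on full masks.
    # ``point_sample`` expects normalized (x, y) coordinates in [0, 1], so the
    # same ``point_coords`` can be shared by predictions and ground-truth
    # masks even if their spatial resolutions differ.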

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single \
                decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        with torch.no_grad():
            points_coords = get_uncertain_point_coords_with_randomness(
                mask_preds.unsqueeze(1), None, self.num_points,
                self.oversample_ratio, self.importance_sample_ratio)
            # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
            mask_point_targets = point_sample(
                mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
        # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_preds = point_sample(
            mask_preds.unsqueeze(1), points_coords).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

        # mask loss
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_preds = mask_point_preds.reshape(-1)
        # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
        mask_point_targets = mask_point_targets.reshape(-1)
        loss_mask = self.loss_mask(
            mask_point_preds,
            mask_point_targets,
            avg_factor=num_total_masks * self.num_points)

        return loss_cls, loss_mask, loss_dice
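
    # Note: ``avg_factor`` returned by ``get_targets`` counts the matched
    # (positive) queries in the batch; ``reduce_mean`` averages that count
    # across GPUs so the dice and mask losses are normalized consistently
    # under distributed training.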

    def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
                      attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]:
        """Forward for head part which is called after every decoder layer.

        Args:
            decoder_out (Tensor): in shape (batch_size, num_queries, c).
            mask_feature (Tensor): in shape (batch_size, c, h, w).
            attn_mask_target_size (tuple[int, int]): target attention
                mask size.

        Returns:
            tuple: A tuple contain three elements.

                - cls_pred (Tensor): Classification scores in shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred (Tensor): Mask scores in shape \
                    (batch_size, num_queries, h, w).
                - attn_mask (Tensor): Attention mask in shape \
                    (batch_size * num_heads, num_queries, h*w).
        """
        decoder_out = self.transformer_decoder.post_norm(decoder_out)
        # shape (batch_size, num_queries, cls_out_channels)
        cls_pred = self.cls_embed(decoder_out)
        # shape (batch_size, num_queries, c)
        mask_embed = self.mask_embed(decoder_out)
        # shape (batch_size, num_queries, h, w)
        mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
        attn_mask = F.interpolate(
            mask_pred,
            attn_mask_target_size,
            mode='bilinear',
            align_corners=False)
        # shape (batch_size, num_queries, h, w) ->
        # (batch_size * num_heads, num_queries, h*w)
        attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
            (1, self.num_heads, 1, 1)).flatten(0, 1)
        attn_mask = attn_mask.sigmoid() < 0.5
        attn_mask = attn_mask.detach()

        return cls_pred, mask_pred, attn_mask
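
    # Note: the ``attn_mask`` returned by ``_forward_head`` is consumed by the
    # next decoder layer's cross-attention in ``forward``. ``True`` entries
    # follow the ``torch.nn.MultiheadAttention`` convention and mark positions
    # a query may not attend to, i.e. pixels whose predicted mask probability
    # is below 0.5 at the resolution of that layer's key/value feature map.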

    def forward(self, x: List[Tensor],
                batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi-scale features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[list[Tensor]]: A tuple contains two elements.

                - cls_pred_list (list[Tensor]): Classification logits \
                    for each decoder layer. Each is a 3D-tensor with shape \
                    (batch_size, num_queries, cls_out_channels). \
                    Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each \
                    decoder layer. Each with shape (batch_size, num_queries, \
                    h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)
        mask_features, multi_scale_memorys = self.pixel_decoder(x)
        # multi_scale_memorys (from low resolution to high resolution)
        decoder_inputs = []
        decoder_positional_encodings = []
        for i in range(self.num_transformer_feat_level):
            decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
            # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
            decoder_input = decoder_input.flatten(2).permute(0, 2, 1)
            level_embed = self.level_embed.weight[i].view(1, 1, -1)
            decoder_input = decoder_input + level_embed
            # padding mask for the positional encoding: no pixel is padded,
            # shape (batch_size, h, w)
            mask = decoder_input.new_zeros(
                (batch_size, ) + multi_scale_memorys[i].shape[-2:],
                dtype=torch.bool)
            decoder_positional_encoding = self.decoder_positional_encoding(
                mask)
            decoder_positional_encoding = decoder_positional_encoding.flatten(
                2).permute(0, 2, 1)
            decoder_inputs.append(decoder_input)
            decoder_positional_encodings.append(decoder_positional_encoding)
        # shape (num_queries, c) -> (batch_size, num_queries, c)
        query_feat = self.query_feat.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(
            (batch_size, 1, 1))

        cls_pred_list = []
        mask_pred_list = []
        cls_pred, mask_pred, attn_mask = self._forward_head(
            query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

        for i in range(self.num_transformer_decoder_layers):
            level_idx = i % self.num_transformer_feat_level
            # if a mask is all True (all background), then set it all False.
            attn_mask[torch.where(
                attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            # cross_attn + self_attn
            layer = self.transformer_decoder.layers[i]
            query_feat = layer(
                query=query_feat,
                key=decoder_inputs[level_idx],
                value=decoder_inputs[level_idx],
                query_pos=query_embed,
                key_pos=decoder_positional_encodings[level_idx],
                cross_attn_mask=attn_mask,
                query_key_padding_mask=None,
                # here we do not apply masking on padded region
                key_padding_mask=None)
            cls_pred, mask_pred, attn_mask = self._forward_head(
                query_feat, mask_features, multi_scale_memorys[
                    (i + 1) % self.num_transformer_feat_level].shape[-2:])

            cls_pred_list.append(cls_pred)
            mask_pred_list.append(mask_pred)

        return cls_pred_list, mask_pred_list
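

# A minimal, illustrative sketch of the tensor mechanics used in
# ``_forward_head``: per-query mask logits come from an einsum between query
# embeddings and per-pixel mask features, and the boolean attention mask is
# obtained by thresholding the resized logits. The shapes below are arbitrary
# placeholders, not values prescribed by any Mask2Former config.
if __name__ == '__main__':
    batch_size, num_queries, channels, num_heads = 2, 5, 8, 2
    h, w = 16, 16
    attn_mask_target_size = (4, 4)

    mask_embed = torch.randn(batch_size, num_queries, channels)
    mask_feature = torch.randn(batch_size, channels, h, w)

    # (batch_size, num_queries, h, w): dot product of each query embedding
    # with the mask feature at every pixel.
    mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)

    # Resize to the next layer's key/value resolution and threshold:
    # True marks positions a query must not attend to.
    attn_mask = F.interpolate(
        mask_pred,
        attn_mask_target_size,
        mode='bilinear',
        align_corners=False)
    attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
        (1, num_heads, 1, 1)).flatten(0, 1)
    attn_mask = (attn_mask.sigmoid() < 0.5).detach()

    print(mask_pred.shape)  # torch.Size([2, 5, 16, 16])
    print(attn_mask.shape)  # torch.Size([4, 5, 16])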