# msdeformattn_pixel_decoder.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, ConvModule
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmengine.model import (BaseModule, ModuleList, caffe2_xavier_init,
                            normal_init, xavier_init)
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from ..task_modules.prior_generators import MlvlPointGenerator
from .positional_encoding import SinePositionalEncoding
from .transformer import Mask2FormerTransformerEncoder


@MODELS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
    """Pixel decoder with multi-scale deformable attention.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        strides (list[int] | tuple[int]): Output strides of features from
            the backbone.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for output.
        num_outs (int): Number of output scales.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`ConfigDict` or dict): Config for the transformer
            encoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for the
            transformer encoder position encoding. Defaults to
            dict(num_feats=128, normalize=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: Union[List[int],
                                    Tuple[int]] = [256, 512, 1024, 2048],
                 strides: Union[List[int], Tuple[int]] = [4, 8, 16, 32],
                 feat_channels: int = 256,
                 out_channels: int = 256,
                 num_outs: int = 3,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 encoder: ConfigType = None,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.strides = strides
        self.num_input_levels = len(in_channels)
        self.num_encoder_levels = \
            encoder.layer_cfg.self_attn_cfg.num_levels
        assert self.num_encoder_levels >= 1, \
            'num_levels in attn_cfgs must be at least one'
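        # With the default arguments (4 input levels and num_levels=3 in the
        # self-attention config), the three lowest-resolution maps go through
        # `input_convs` and the deformable encoder built below, while the
        # remaining high-resolution map is only merged via the FPN-style
        # lateral path.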
        input_conv_list = []
        # from top to down (low to high resolution)
        for i in range(self.num_input_levels - 1,
                       self.num_input_levels - self.num_encoder_levels - 1,
                       -1):
            input_conv = ConvModule(
                in_channels[i],
                feat_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=None,
                bias=True)
            input_conv_list.append(input_conv)
        self.input_convs = ModuleList(input_conv_list)

        self.encoder = Mask2FormerTransformerEncoder(**encoder)
        self.positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        # high resolution to low resolution
        self.level_encoding = nn.Embedding(self.num_encoder_levels,
                                           feat_channels)

        # fpn-like structure
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        self.use_bias = norm_cfg is None
        # from top to down (low to high resolution)
        # fpn for the rest features that didn't pass in encoder
        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
                       -1):
            lateral_conv = ConvModule(
                in_channels[i],
                feat_channels,
                kernel_size=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None)
            output_conv = ConvModule(
                feat_channels,
                feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.lateral_convs.append(lateral_conv)
            self.output_convs.append(output_conv)

        self.mask_feature = Conv2d(
            feat_channels, out_channels, kernel_size=1, stride=1, padding=0)

        self.num_outs = num_outs
        self.point_generator = MlvlPointGenerator(strides)

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_encoder_levels):
            xavier_init(
                self.input_convs[i].conv,
                gain=1,
                bias=0,
                distribution='uniform')

        for i in range(0, self.num_input_levels - self.num_encoder_levels):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)

        normal_init(self.level_encoding, mean=0, std=1)
        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

        # init_weights defined in MultiScaleDeformableAttention
        for m in self.encoder.layers.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()

    def forward(self, feats: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape of (batch_size, c, h, w).

        Returns:
            tuple: A tuple containing the following:

                - mask_feature (Tensor): Shape (batch_size, c, h, w).
                - multi_scale_features (list[Tensor]): Multi-scale
                  features, each in shape (batch_size, c, h, w).
        """
        # generate padding mask for each level, for each image
        batch_size = feats[0].shape[0]
        encoder_input_list = []
        padding_mask_list = []
        level_positional_encoding_list = []
        spatial_shapes = []
        reference_points_list = []
        for i in range(self.num_encoder_levels):
            level_idx = self.num_input_levels - i - 1
            feat = feats[level_idx]
            feat_projected = self.input_convs[i](feat)
            h, w = feat.shape[-2:]

            # no padding
            padding_mask_resized = feat.new_zeros(
                (batch_size, ) + feat.shape[-2:], dtype=torch.bool)
            pos_embed = self.positional_encoding(padding_mask_resized)
            level_embed = self.level_encoding.weight[i]
            level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
            # (h_i * w_i, 2)
            reference_points = self.point_generator.single_level_grid_priors(
                feat.shape[-2:], level_idx, device=feat.device)
            # normalize to the [0, 1] range of the input image
            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
            reference_points = reference_points / factor

            # shape (batch_size, c, h_i, w_i) -> (batch_size, h_i * w_i, c)
            feat_projected = feat_projected.flatten(2).permute(0, 2, 1)
            level_pos_embed = level_pos_embed.flatten(2).permute(0, 2, 1)
            padding_mask_resized = padding_mask_resized.flatten(1)

            encoder_input_list.append(feat_projected)
            padding_mask_list.append(padding_mask_resized)
            level_positional_encoding_list.append(level_pos_embed)
            spatial_shapes.append(feat.shape[-2:])
            reference_points_list.append(reference_points)

        # shape (batch_size, total_num_queries),
        # total_num_queries = sum(h_i * w_i) over all encoder levels
        padding_masks = torch.cat(padding_mask_list, dim=1)
        # shape (batch_size, total_num_queries, c)
        encoder_inputs = torch.cat(encoder_input_list, dim=1)
        level_positional_encodings = torch.cat(
            level_positional_encoding_list, dim=1)
        device = encoder_inputs.device
        # shape (num_encoder_levels, 2), from low
        # resolution to high resolution
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=device)
        # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
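        # e.g. spatial_shapes = [[16, 16], [32, 32], [64, 64]] gives
        # level_start_index = [0, 256, 1280]: the offset of each level's
        # first query in the flattened query sequence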
        reference_points = torch.cat(reference_points_list, dim=0)
        reference_points = reference_points[None, :, None].repeat(
            batch_size, 1, self.num_encoder_levels, 1)
        valid_ratios = reference_points.new_ones(
            (batch_size, self.num_encoder_levels, 2))
        # shape (batch_size, num_total_queries, c)
        memory = self.encoder(
            query=encoder_inputs,
            query_pos=level_positional_encodings,
            key_padding_mask=padding_masks,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        # (batch_size, c, num_total_queries)
        memory = memory.permute(0, 2, 1)

        # from low resolution to high resolution
        num_queries_per_level = [e[0] * e[1] for e in spatial_shapes]
        outs = torch.split(memory, num_queries_per_level, dim=-1)
        outs = [
            x.reshape(batch_size, -1, spatial_shapes[i][0],
                      spatial_shapes[i][1]) for i, x in enumerate(outs)
        ]

        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
                       -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + F.interpolate(
                outs[-1],
                size=cur_feat.shape[-2:],
                mode='bilinear',
                align_corners=False)
            y = self.output_convs[i](y)
            outs.append(y)

        multi_scale_features = outs[:self.num_outs]

        mask_feature = self.mask_feature(outs[-1])
        return mask_feature, multi_scale_features
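

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds the decoder with an
# encoder config modelled on the standard Mask2Former R-50 setting and runs a
# forward pass on random features. The encoder/FFN hyper-parameters below are
# assumptions for illustration, not values defined in this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from mmengine.config import ConfigDict

    encoder_cfg = ConfigDict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(  # MultiScaleDeformableAttention
                embed_dims=256,
                num_heads=8,
                num_levels=3,
                num_points=4,
                dropout=0.0,
                batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,
                num_fcs=2,
                ffn_drop=0.0,
                act_cfg=dict(type='ReLU', inplace=True))))

    decoder = MSDeformAttnPixelDecoder(
        in_channels=[256, 512, 1024, 2048],
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_outs=3,
        encoder=encoder_cfg,
        positional_encoding=dict(num_feats=128, normalize=True))
    decoder.init_weights()

    # Backbone-like feature pyramid for a 256x256 input, batch size 2.
    feats = [
        torch.rand(2, c, 256 // s, 256 // s)
        for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
    ]
    mask_feature, multi_scale_features = decoder(feats)
    print(mask_feature.shape)  # (2, 256, 64, 64), stride-4 mask features
    print([f.shape for f in multi_scale_features])  # 3 maps: stride 32/16/8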