# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmengine.model import xavier_init
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.utils import OptConfigType
from ..layers import (DeformableDetrTransformerDecoder,
                      DeformableDetrTransformerEncoder,
                      SinePositionalEncoding)
from .base_detr import DetectionTransformer


@MODELS.register_module()
class DeformableDETR(DetectionTransformer):
    r"""Implementation of `Deformable DETR: Deformable Transformers for
    End-to-End Object Detection <https://arxiv.org/abs/2010.04159>`_

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    Args:
        decoder (:obj:`ConfigDict` or dict, optional): Config of the
            Transformer decoder. Defaults to None.
        bbox_head (:obj:`ConfigDict` or dict, optional): Config for the
            bounding box head module. Defaults to None.
        with_box_refine (bool, optional): Whether to refine the references
            in the decoder. Defaults to `False`.
        as_two_stage (bool, optional): Whether to generate the proposal
            from the outputs of encoder. Defaults to `False`.
        num_feature_levels (int, optional): Number of feature levels.
            Defaults to 4.
    """

    def __init__(self,
                 *args,
                 decoder: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 with_box_refine: bool = False,
                 as_two_stage: bool = False,
                 num_feature_levels: int = 4,
                 **kwargs) -> None:
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels

        if bbox_head is not None:
            assert 'share_pred_layer' not in bbox_head and \
                   'num_pred_layer' not in bbox_head and \
                   'as_two_stage' not in bbox_head, \
                'The three keyword args `share_pred_layer`, ' \
                '`num_pred_layer`, and `as_two_stage` are set in ' \
                '`detector.__init__()`, users should not set them in the ' \
                '`bbox_head` config.'
            # The last prediction layer is used to generate proposals
            # from the encoded feature map when `as_two_stage` is `True`.
            # All the prediction layers share parameters when
            # `with_box_refine` is `False`, and each decoder layer gets its
            # own prediction layer when it is `True`.
            bbox_head['share_pred_layer'] = not with_box_refine
            bbox_head['num_pred_layer'] = (decoder['num_layers'] + 1) \
                if self.as_two_stage else decoder['num_layers']
            bbox_head['as_two_stage'] = as_two_stage

        super().__init__(*args, decoder=decoder, bbox_head=bbox_head, **kwargs)
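
    # For instance, with a hypothetical decoder config of
    # `decoder=dict(num_layers=6, ...)`, the head receives
    # `num_pred_layer=7` in the two-stage setting (six decoder layers plus
    # one extra layer for the encoder proposals) and `num_pred_layer=6`
    # otherwise.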

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DeformableDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        if not self.as_two_stage:
            self.query_embedding = nn.Embedding(self.num_queries,
                                                self.embed_dims * 2)
            # NOTE The query_embedding will be split into query and query_pos
            # in self.pre_decoder, hence, the embed_dims are doubled.

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))

        if self.as_two_stage:
            self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
            self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans_fc = nn.Linear(self.embed_dims * 2,
                                          self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            self.reference_points_fc = nn.Linear(self.embed_dims, 2)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        if self.as_two_stage:
            nn.init.xavier_uniform_(self.memory_trans_fc.weight)
            nn.init.xavier_uniform_(self.pos_trans_fc.weight)
        else:
            xavier_init(
                self.reference_points_fc, distribution='uniform', bias=0.)
        normal_(self.level_embed)

    def pre_transformer(
            self,
            mlvl_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict]:
        """Process image features before feeding them to the transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            mlvl_feats (tuple[Tensor]): Multi-level features that may have
                different resolutions, output from neck. Each feature has
                shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The first dict contains the inputs of encoder and the
            second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              and 'feat_pos'.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask'.
        """
        batch_size = mlvl_feats[0].size(0)

        # construct binary masks for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        input_img_h, input_img_w = batch_input_shape
        masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.

        mlvl_masks = []
        mlvl_pos_embeds = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1]))

        feat_flatten = []
        lvl_pos_embed_flatten = []
        mask_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            batch_size, c, h, w = feat.shape
            # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
            feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
            pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
            mask = mask.flatten(1)
            spatial_shape = (h, w)

            feat_flatten.append(feat)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            mask_flatten.append(mask)
            spatial_shapes.append(spatial_shape)

        # (bs, num_feat_points, dim)
        feat_flatten = torch.cat(feat_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
        mask_flatten = torch.cat(mask_flatten, 1)

        spatial_shapes = torch.as_tensor(  # (num_level, 2)
            spatial_shapes,
            dtype=torch.long,
            device=feat_flatten.device)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_level)
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(  # (bs, num_level, 2)
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        encoder_inputs_dict = dict(
            feat=feat_flatten,
            feat_mask=mask_flatten,
            feat_pos=lvl_pos_embed_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        decoder_inputs_dict = dict(
            memory_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        return encoder_inputs_dict, decoder_inputs_dict
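
    # A toy example of the flattening bookkeeping above, assuming two feature
    # levels with shapes (100, 134) and (50, 67):
    #   spatial_shapes    -> tensor([[100, 134], [50, 67]])
    #   num_feat_points   -> 100 * 134 + 50 * 67 = 16750
    #   level_start_index -> tensor([0, 13400])
    # so `feat`, `feat_mask` and `feat_pos` all share the length 16750 along
    # the sequence dimension.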

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor) -> Dict:
        """Forward with Transformer encoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            feat (Tensor): Sequential features, has shape
                (bs, num_feat_points, dim).
            feat_mask (Tensor): ByteTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output.
        """
        memory = self.encoder(
            query=feat,
            query_pos=feat_pos,
            key_padding_mask=feat_mask,  # for self_attn
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        encoder_outputs_dict = dict(
            memory=memory,
            memory_mask=feat_mask,
            spatial_shapes=spatial_shapes)
        return encoder_outputs_dict

    def pre_decoder(self, memory: Tensor, memory_mask: Tensor,
                    spatial_shapes: Tensor) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). It will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
                It will only be used when `as_two_stage` is `True`.

        Returns:
            tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'query', 'query_pos',
              'memory', and `reference_points`. The reference points of the
              decoder input here are 4D boxes when `as_two_stage` is `True`,
              otherwise 2D points, although it has `points` in its name.
              The reference points in the encoder are always 2D points.
            - head_inputs_dict (dict): The keyword args dictionary of the
              bbox_head functions, which includes `enc_outputs_class` and
              `enc_outputs_coord`. They are both `None` when 'as_two_stage'
              is `False`. The dict is empty when `self.training` is `False`.
        """
        batch_size, _, c = memory.shape
        if self.as_two_stage:
            output_memory, output_proposals = \
                self.gen_encoder_output_proposals(
                    memory, memory_mask, spatial_shapes)
            enc_outputs_class = self.bbox_head.cls_branches[
                self.decoder.num_layers](
                    output_memory)
            enc_outputs_coord_unact = self.bbox_head.reg_branches[
                self.decoder.num_layers](output_memory) + output_proposals
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            # We only use the first channel in enc_outputs_class as
            # foreground, the other (num_classes - 1) channels are actually
            # not used. Its targets are set to be 0s, which indicates the
            # first class (foreground) because we use [0, num_classes - 1] to
            # indicate class labels, background class is indicated by
            # num_classes (similar convention in RPN).
            # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
            # This follows the official implementation of Deformable DETR.
            topk_proposals = torch.topk(
                enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            pos_trans_out = self.pos_trans_fc(
                self.get_proposal_pos_embed(topk_coords_unact))
            pos_trans_out = self.pos_trans_norm(pos_trans_out)
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            enc_outputs_class, enc_outputs_coord = None, None
            query_embed = self.query_embedding.weight
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1)
            query = query.unsqueeze(0).expand(batch_size, -1, -1)
            reference_points = self.reference_points_fc(query_pos).sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            query_pos=query_pos,
            memory=memory,
            reference_points=reference_points)
        head_inputs_dict = dict(
            enc_outputs_class=enc_outputs_class,
            enc_outputs_coord=enc_outputs_coord) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict
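
    # A brief shape sketch of the two branches above (assuming
    # `num_queries=300` and `embed_dims=256`):
    #   two-stage: topk_coords_unact has shape (bs, 300, 4), so the reference
    #              points are 4D boxes; query_pos/query are split from the
    #              (bs, 300, 512) output of `pos_trans_fc` into two
    #              (bs, 300, 256) tensors.
    #   one-stage: `query_embedding.weight` of shape (300, 512) is split into
    #              query_pos/query of shape (300, 256) each, broadcast to
    #              (bs, 300, 256), and reference_points of shape (bs, 300, 2)
    #              come from `reference_points_fc`.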

    def forward_decoder(self, query: Tensor, query_pos: Tensor,
                        memory: Tensor, memory_mask: Tensor,
                        reference_points: Tensor, spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged as
                (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, inter_references = self.decoder(
            query=query,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=memory_mask,  # for cross_attn
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches
            if self.with_box_refine else None)
        references = [reference_points, *inter_references]
        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=references)
        return decoder_outputs_dict
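
    # Note on the output sizes, assuming the decoder returns one set of
    # refined reference points per layer: `references` then holds
    # `num_layers + 1` entries (the initial points plus one per layer), while
    # `hidden_states` stacks the per-layer query embeddings.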

    @staticmethod
    def get_valid_ratio(mask: Tensor) -> Tensor:
        """Get the valid ratios of feature map in a level.

        .. code:: text

                    |---> valid_W <---|
                 ---+-----------------+-----+---
                  A |                 |     | A
                  | |                 |     | |
                  | |                 |     | |
            valid_H |                 |     | |
                  | |                 |     | H
                  | |                 |     | |
                  V |                 |     | |
                 ---+-----------------+     | |
                    |                       | V
                    +-----------------------+---
                    |---------> W <---------|

        The valid_ratios are defined as:

            r_h = valid_H / H,  r_w = valid_W / W

        They are the factors to re-normalize the relative coordinates of the
        image to the relative coordinates of the current level feature map.

        Args:
            mask (Tensor): Binary mask of a feature map, has shape (bs, H, W).

        Returns:
            Tensor: valid ratios [r_w, r_h] of a feature map, has shape
            (bs, 2).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
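
    # For example (illustrative numbers): if the padded mask at some level is
    # 32x32 and only the top-left 20 rows x 28 columns are valid (zeros),
    # then valid_H = 20, valid_W = 28, and the returned ratios are
    # [r_w, r_h] = [28 / 32, 20 / 32] = [0.875, 0.625] for that sample.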

    def gen_encoder_output_proposals(
            self, memory: Tensor, memory_mask: Tensor,
            spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
        """Generate proposals from encoded memory. The function will only be
        used when `as_two_stage` is `True`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).

        Returns:
            tuple: A tuple of transformed memory and proposals.

            - output_memory (Tensor): The transformed memory for obtaining
              top-k proposals, has shape (bs, num_feat_points, dim).
            - output_proposals (Tensor): The inverse-normalized proposal, has
              shape (batch_size, num_keys, 4) with the last dimension arranged
              as (cx, cy, w, h).
        """
        bs = memory.size(0)
        proposals = []
        _cur = 0  # start index in the sequence of the current level
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_mask[:,
                                        _cur:(_cur + H * W)].view(bs, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(
                    0, W - 1, W, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
            proposals.append(proposal)
            _cur += (H * W)

        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) &
                                  (output_proposals < 0.99)).all(
                                      -1, keepdim=True)
        # inverse_sigmoid
        output_proposals = torch.log(
            output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(
            ~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(
            memory_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid,
                                                  float(0))
        output_memory = self.memory_trans_fc(output_memory)
        output_memory = self.memory_trans_norm(output_memory)
        # output_memory: (bs, sum(hw), dim); output_proposals: (bs, sum(hw), 4)
        return output_memory, output_proposals
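
    # The proposal geometry above, in words: every feature location becomes a
    # box centred on its cell centre (normalized by the valid, unpadded part
    # of its level) with a fixed width/height of 0.05 * 2**lvl, i.e. roughly
    # 0.05, 0.1, 0.2 and 0.4 for levels 0-3 with the default four levels.
    # Padded positions and near-border proposals are filled with +inf after
    # the inverse sigmoid, and the corresponding memory entries are zeroed
    # before the linear-norm transform.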

    @staticmethod
    def get_proposal_pos_embed(proposals: Tensor,
                               num_pos_feats: int = 128,
                               temperature: int = 10000) -> Tensor:
        """Get the position embedding of the proposal.

        Args:
            proposals (Tensor): Not normalized proposals, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            num_pos_feats (int, optional): The feature dimension for each
                position along x, y, w, and h-axis. Note the final returned
                dimension for each position is 4 times of num_pos_feats.
                Defaults to 128.
            temperature (int, optional): The temperature used for scaling the
                position embedding. Defaults to 10000.

        Returns:
            Tensor: The position embedding of proposal, has shape
            (bs, num_queries, num_pos_feats * 4), with the last dimension
            arranged as (cx, cy, w, h).
        """
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
                          dim=4).flatten(2)
        return pos
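
    # Worked shape example for the embedding above, assuming the defaults
    # (num_pos_feats=128, embed_dims=256): each of the 4 box coordinates is
    # mapped to a 128-d sine/cosine vector, giving a (bs, num_queries, 512)
    # tensor, which matches the input dimension of `pos_trans_fc`
    # (embed_dims * 2 = 512) used in `pre_decoder`.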