# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import List, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.structures import InstanceData
from mmengine.utils import digit_version
from six.moves import map, zip
from torch import Tensor
from torch.autograd import Function
from torch.nn import functional as F

from mmdet.structures import SampleList
from mmdet.structures.bbox import BaseBoxes, get_box_type, stack_boxes
from mmdet.structures.mask import BitmapMasks, PolygonMasks
from mmdet.utils import OptInstanceList


class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.

    This implementation with analytical gradient function substitutes
    the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
    original implementation incurs NaN during gradient backpropagation
    if both x and y are very small values.
    """

    @staticmethod
    def forward(ctx, x, y):
        x_sigmoid = x.sigmoid()
        y_sigmoid = y.sigmoid()
        z = (x_sigmoid * y_sigmoid).sqrt()
        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
        grad_x = grad_output * z * (1 - x_sigmoid) / 2
        grad_y = grad_output * z * (1 - y_sigmoid) / 2
        return grad_x, grad_y


sigmoid_geometric_mean = SigmoidGeometricMean.apply
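

# Illustrative sketch (not part of the original module; the `_demo_*` helper
# name is hypothetical): with very negative logits the product of sigmoids
# underflows to zero in float32, yet the analytical backward stays finite.
def _demo_sigmoid_geometric_mean():
    x = torch.full((3, ), -80.0, requires_grad=True)
    y = torch.full((3, ), -80.0, requires_grad=True)
    sigmoid_geometric_mean(x, y).sum().backward()
    assert torch.isfinite(x.grad).all() and torch.isfinite(y.grad).all()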


def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w).

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().

    Returns:
        Tensor: The interpolated source Tensor.
    """
    assert len(target.shape) >= 2

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        if target_h != source_h or target_w != source_w:
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    else:
        return _interpolate_as(source, target, mode, align_corners)
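

# Usage sketch (illustrative, not part of the original module): resizing a
# batch of 3D mask logits to the spatial size of a target array.
def _demo_interpolate_as():
    source = torch.rand(2, 32, 32)  # (N, H, W)
    target = np.zeros((128, 128))  # only the trailing (h, w) shape is used
    resized = interpolate_as(source, target)
    assert resized.shape == (2, 128, 128)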


def unpack_gt_instances(batch_data_samples: SampleList) -> tuple:
    """Unpack ``gt_instances``, ``gt_instances_ignore`` and ``img_metas``
    based on ``batch_data_samples``.

    Args:
        batch_data_samples (List[:obj:`DetDataSample`]): The Data
            Samples. It usually includes information such as
            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

    Returns:
        tuple:

            - batch_gt_instances (list[:obj:`InstanceData`]): Batch of
              gt_instance. It usually includes ``bboxes`` and ``labels``
              attributes.
            - batch_gt_instances_ignore (list[:obj:`InstanceData`]):
              Batch of gt_instances_ignore. It includes ``bboxes`` attribute
              data that is ignored during training and testing.
              Defaults to None.
            - batch_img_metas (list[dict]): Meta information of each image,
              e.g., image size, scaling factor, etc.
    """
    batch_gt_instances = []
    batch_gt_instances_ignore = []
    batch_img_metas = []
    for data_sample in batch_data_samples:
        batch_img_metas.append(data_sample.metainfo)
        batch_gt_instances.append(data_sample.gt_instances)
        if 'ignored_instances' in data_sample:
            batch_gt_instances_ignore.append(data_sample.ignored_instances)
        else:
            batch_gt_instances_ignore.append(None)

    return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas


def empty_instances(batch_img_metas: List[dict],
                    device: torch.device,
                    task_type: str,
                    instance_results: OptInstanceList = None,
                    mask_thr_binary: Union[int, float] = 0,
                    box_type: Union[str, type] = 'hbox',
                    use_box_type: bool = False,
                    num_classes: int = 80,
                    score_per_cls: bool = False) -> List[InstanceData]:
    """Handle predicted instances when RoI is empty.

    Note: If ``instance_results`` is not None, it will be modified
    in place internally, and then return ``instance_results``.

    Args:
        batch_img_metas (list[dict]): List of image information.
        device (torch.device): Device of tensor.
        task_type (str): Expected returned task type. It currently
            supports bbox and mask.
        instance_results (list[:obj:`InstanceData`]): List of instance
            results.
        mask_thr_binary (int, float): mask binarization threshold.
            Defaults to 0.
        box_type (str or type): The empty box type. Defaults to `hbox`.
        use_box_type (bool): Whether to wrap boxes with the box type.
            Defaults to False.
        num_classes (int): num_classes of bbox_head. Defaults to 80.
        score_per_cls (bool): Whether to generate classwise score for
            the empty instance. ``score_per_cls`` will be True when the model
            needs to produce raw results without nms. Defaults to False.

    Returns:
        list[:obj:`InstanceData`]: Detection results of each image.
    """
    assert task_type in ('bbox', 'mask'), 'Only support bbox and mask,' \
        f' but got {task_type}'

    if instance_results is not None:
        assert len(instance_results) == len(batch_img_metas)

    results_list = []
    for img_id in range(len(batch_img_metas)):
        if instance_results is not None:
            results = instance_results[img_id]
            assert isinstance(results, InstanceData)
        else:
            results = InstanceData()

        if task_type == 'bbox':
            _, box_type = get_box_type(box_type)
            bboxes = torch.zeros(0, box_type.box_dim, device=device)
            if use_box_type:
                bboxes = box_type(bboxes, clone=False)
            results.bboxes = bboxes
            score_shape = (0, num_classes + 1) if score_per_cls else (0, )
            results.scores = torch.zeros(score_shape, device=device)
            results.labels = torch.zeros((0, ),
                                         device=device,
                                         dtype=torch.long)
        else:
            # TODO: Handle the case where rescale is false
            img_h, img_w = batch_img_metas[img_id]['ori_shape'][:2]
            # the type of `im_mask` will be torch.bool or torch.uint8,
            # where uint8 is for visualization and debugging
            im_mask = torch.zeros(
                0,
                img_h,
                img_w,
                device=device,
                dtype=torch.bool if mask_thr_binary >= 0 else torch.uint8)
            results.masks = im_mask
        results_list.append(results)
    return results_list
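

# Usage sketch (illustrative, not part of the original module): building an
# empty mask result for a single image on CPU.
def _demo_empty_instances():
    metas = [dict(ori_shape=(100, 120))]
    results = empty_instances(metas, torch.device('cpu'), task_type='mask')
    assert results[0].masks.shape == (0, 100, 120)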


def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies the ``func`` to multiple inputs and
        maps the multiple outputs of the ``func`` into different
        lists. Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments.

    Returns:
        tuple(list): A tuple containing multiple lists, each of which
        contains a kind of returned result of the function.
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))
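

# Usage sketch (illustrative, not part of the original module): a per-sample
# function returning two values is transposed into two lists.
def _demo_multi_apply():

    def add_and_mul(a, b, scale=1):
        return a + b, a * b * scale

    sums, prods = multi_apply(add_and_mul, [1, 2], [3, 4], scale=10)
    assert sums == [4, 6] and prods == [30, 80]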


def unmap(data, count, inds, fill=0):
    """Unmap a subset of items (data) back to the original set of items (of
    size count)."""
    if data.dim() == 1:
        ret = data.new_full((count, ), fill)
        ret[inds.type(torch.bool)] = data
    else:
        new_size = (count, ) + data.size()[1:]
        ret = data.new_full(new_size, fill)
        ret[inds.type(torch.bool), :] = data
    return ret
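

# Usage sketch (illustrative, not part of the original module): scatter the
# values of a valid subset back into a full-size tensor, padding with `fill`.
def _demo_unmap():
    data = torch.tensor([10.0, 20.0])
    inds = torch.tensor([0, 1, 0, 1], dtype=torch.bool)
    full = unmap(data, count=4, inds=inds, fill=-1)
    assert full.tolist() == [-1.0, 10.0, -1.0, 20.0]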


def mask2ndarray(mask):
    """Convert Mask to ndarray.

    Args:
        mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or
            torch.Tensor or np.ndarray): The mask to be converted.

    Returns:
        np.ndarray: Ndarray mask of shape (n, h, w) that has been converted.
    """
    if isinstance(mask, (BitmapMasks, PolygonMasks)):
        mask = mask.to_ndarray()
    elif isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    elif not isinstance(mask, np.ndarray):
        raise TypeError(f'Unsupported {type(mask)} data type')
    return mask
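

# Usage sketch (illustrative, not part of the original module): a boolean
# torch mask round-trips to a NumPy array of the same shape.
def _demo_mask2ndarray():
    masks = torch.zeros(2, 4, 4, dtype=torch.bool)
    arr = mask2ndarray(masks)
    assert isinstance(arr, np.ndarray) and arr.shape == (2, 4, 4)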


def flip_tensor(src_tensor, flip_direction):
    """Flip tensor based on flip_direction.

    Args:
        src_tensor (Tensor): input feature map, shape (B, C, H, W).
        flip_direction (str): The flipping direction. Options are
            'horizontal', 'vertical', 'diagonal'.

    Returns:
        out_tensor (Tensor): Flipped tensor.
    """
    assert src_tensor.ndim == 4
    valid_directions = ['horizontal', 'vertical', 'diagonal']
    assert flip_direction in valid_directions
    if flip_direction == 'horizontal':
        out_tensor = torch.flip(src_tensor, [3])
    elif flip_direction == 'vertical':
        out_tensor = torch.flip(src_tensor, [2])
    else:
        out_tensor = torch.flip(src_tensor, [2, 3])
    return out_tensor
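

# Usage sketch (illustrative, not part of the original module): flipping
# horizontally twice restores the input.
def _demo_flip_tensor():
    x = torch.arange(8.0).reshape(1, 1, 2, 4)
    flipped = flip_tensor(x, 'horizontal')
    assert torch.equal(flipped[..., 0], x[..., -1])
    assert torch.equal(flip_tensor(flipped, 'horizontal'), x)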


def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Extract a multi-scale single image tensor from a multi-scale batch
    tensor based on batch index.

    Note: The default value of detach is True, because the proposal gradient
    needs to be detached during the training of the two-stage model, e.g.,
    Cascade Mask R-CNN.

    Args:
        mlvl_tensors (list[Tensor]): Batch tensor for all scale levels,
            each is a 4D-tensor.
        batch_id (int): Batch index.
        detach (bool): Whether to detach the gradient. Default True.

    Returns:
        list[Tensor]: Multi-scale single image tensor.
    """
    assert isinstance(mlvl_tensors, (list, tuple))
    num_levels = len(mlvl_tensors)

    if detach:
        mlvl_tensor_list = [
            mlvl_tensors[i][batch_id].detach() for i in range(num_levels)
        ]
    else:
        mlvl_tensor_list = [
            mlvl_tensors[i][batch_id] for i in range(num_levels)
        ]
    return mlvl_tensor_list
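

# Usage sketch (illustrative, not part of the original module): picking the
# second image's feature maps from a two-level batch.
def _demo_select_single_mlvl():
    mlvl = [torch.rand(2, 4, size, size) for size in (8, 4)]
    single = select_single_mlvl(mlvl, batch_id=1)
    assert [t.shape for t in single] == [(4, 8, 8), (4, 4, 4)]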


def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Filter results using score threshold and topk candidates.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): The score filter threshold.
        topk (int): The number of topk candidates.
        results (dict or list or Tensor, Optional): The results to
            which the filtering rule is to be applied. The shape
            of each item is (num_bboxes, N).

    Returns:
        tuple: Filtered results

            - scores (Tensor): The scores after being filtered, \
                shape (num_bboxes_filtered, ).
            - labels (Tensor): The class labels, shape \
                (num_bboxes_filtered, ).
            - anchor_idxs (Tensor): The anchor indexes, shape \
                (num_bboxes_filtered, ).
            - filtered_results (dict or list or Tensor, Optional): \
                The filtered results. The shape of each item is \
                (num_bboxes_filtered, N).
    """
    valid_mask = scores > score_thr
    scores = scores[valid_mask]
    valid_idxs = torch.nonzero(valid_mask)

    num_topk = min(topk, valid_idxs.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    scores, idxs = scores.sort(descending=True)
    scores = scores[:num_topk]
    topk_idxs = valid_idxs[idxs[:num_topk]]
    keep_idxs, labels = topk_idxs.unbind(dim=1)

    filtered_results = None
    if results is not None:
        if isinstance(results, dict):
            filtered_results = {k: v[keep_idxs] for k, v in results.items()}
        elif isinstance(results, list):
            filtered_results = [result[keep_idxs] for result in results]
        elif isinstance(results, torch.Tensor):
            filtered_results = results[keep_idxs]
        else:
            raise NotImplementedError(
                f'Only supports dict or list or Tensor, '
                f'but got {type(results)}.')
    return scores, labels, keep_idxs, filtered_results
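

# Usage sketch (illustrative, not part of the original module): keep scores
# above 0.3, at most the top two, and gather matching rows of a bbox tensor.
def _demo_filter_scores_and_topk():
    scores = torch.tensor([[0.9, 0.1], [0.2, 0.6], [0.4, 0.05]])
    bboxes = torch.rand(3, 4)
    kept_scores, labels, keep_idxs, kept_bboxes = filter_scores_and_topk(
        scores, score_thr=0.3, topk=2, results=bboxes)
    assert torch.allclose(kept_scores, torch.tensor([0.9, 0.6]))
    assert labels.tolist() == [0, 1] and keep_idxs.tolist() == [0, 1]
    assert kept_bboxes.shape == (2, 4)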


def center_of_mass(mask, esp=1e-6):
    """Calculate the centroid coordinates of the mask.

    Args:
        mask (Tensor): The mask to be calculated, shape (h, w).
        esp (float): Avoid dividing by zero. Default: 1e-6.

    Returns:
        tuple[Tensor]: the coordinates of the center point of the mask.

            - center_h (Tensor): the center point of the height.
            - center_w (Tensor): the center point of the width.
    """
    h, w = mask.shape
    grid_h = torch.arange(h, device=mask.device)[:, None]
    grid_w = torch.arange(w, device=mask.device)
    normalizer = mask.sum().float().clamp(min=esp)
    center_h = (mask * grid_h).sum() / normalizer
    center_w = (mask * grid_w).sum() / normalizer
    return center_h, center_w
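

# Usage sketch (illustrative, not part of the original module): a full-image
# mask yields the mean grid coordinate as its centroid.
def _demo_center_of_mass():
    mask = torch.ones(5, 7)
    center_h, center_w = center_of_mass(mask)
    assert center_h.item() == 2.0 and center_w.item() == 3.0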


def generate_coordinate(featmap_sizes, device='cuda'):
    """Generate the coordinate.

    Args:
        featmap_sizes (tuple): The size of the feature map to generate
            coordinates for, i.e. (N, C, H, W).
        device (str): The device where the feature will be put on.

    Returns:
        coord_feat (Tensor): The coordinate feature, of shape (N, 2, H, W).
    """
    x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
    y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
    y, x = torch.meshgrid(y_range, x_range)
    y = y.expand([featmap_sizes[0], 1, -1, -1])
    x = x.expand([featmap_sizes[0], 1, -1, -1])
    coord_feat = torch.cat([x, y], 1)
    return coord_feat
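

# Usage sketch (illustrative, not part of the original module): CPU device
# used so the sketch runs anywhere; channel 0 is normalized x, channel 1 is
# normalized y.
def _demo_generate_coordinate():
    coord = generate_coordinate((2, 256, 8, 16), device='cpu')
    assert coord.shape == (2, 2, 8, 16)
    assert coord[0, 0, 0, 0] == -1 and coord[0, 0, 0, -1] == 1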


def levels_to_images(mlvl_tensor: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concat multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]

    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
    (N, H*W, C), then split the element into N elements with shape (H*W, C),
    and concat elements in the same image of all levels along the first
    dimension.

    Args:
        mlvl_tensor (list[Tensor]): list of Tensors collected from the
            corresponding levels. Each element is of shape (N, C, H, W).

    Returns:
        list[Tensor]: A list that contains N tensors, each of
        shape (num_elements, C).
    """
    batch_size = mlvl_tensor[0].size(0)
    batch_list = [[] for _ in range(batch_size)]
    channels = mlvl_tensor[0].size(1)
    for t in mlvl_tensor:
        t = t.permute(0, 2, 3, 1)
        t = t.view(batch_size, -1, channels).contiguous()
        for img in range(batch_size):
            batch_list[img].append(t[img])
    return [torch.cat(item, 0) for item in batch_list]
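

# Usage sketch (illustrative, not part of the original module): two levels
# with 4x4 and 2x2 maps give one (16 + 4, C) tensor per image.
def _demo_levels_to_images():
    mlvl = [torch.rand(2, 5, 4, 4), torch.rand(2, 5, 2, 2)]
    per_img = levels_to_images(mlvl)
    assert len(per_img) == 2 and per_img[0].shape == (20, 5)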


def images_to_levels(target, num_levels):
    """Convert targets by image to targets by feature level.

    [target_img0, target_img1] -> [target_level0, target_level1, ...]
    """
    target = stack_boxes(target, 0)
    level_targets = []
    start = 0
    for n in num_levels:
        end = start + n
        # level_targets.append(target[:, start:end].squeeze(0))
        level_targets.append(target[:, start:end])
        start = end
    return level_targets


def samplelist_boxtype2tensor(batch_data_samples: SampleList) -> SampleList:
    for data_samples in batch_data_samples:
        if 'gt_instances' in data_samples:
            bboxes = data_samples.gt_instances.get('bboxes', None)
            if isinstance(bboxes, BaseBoxes):
                data_samples.gt_instances.bboxes = bboxes.tensor
        if 'pred_instances' in data_samples:
            bboxes = data_samples.pred_instances.get('bboxes', None)
            if isinstance(bboxes, BaseBoxes):
                data_samples.pred_instances.bboxes = bboxes.tensor
        if 'ignored_instances' in data_samples:
            bboxes = data_samples.ignored_instances.get('bboxes', None)
            if isinstance(bboxes, BaseBoxes):
                data_samples.ignored_instances.bboxes = bboxes.tensor


_torch_version_div_indexing = (
    'parrots' not in torch.__version__
    and digit_version(torch.__version__) >= digit_version('1.8'))


def floordiv(dividend, divisor, rounding_mode='trunc'):
    if _torch_version_div_indexing:
        return torch.div(dividend, divisor, rounding_mode=rounding_mode)
    else:
        return dividend // divisor
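

# Usage sketch (illustrative, not part of the original module): trunc-mode
# floor division recovers the row index from a flattened index.
def _demo_floordiv():
    flat_idx = torch.tensor([0, 5, 13])
    assert floordiv(flat_idx, 4).tolist() == [0, 1, 3]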


def _filter_gt_instances_by_score(batch_data_samples: SampleList,
                                  score_thr: float) -> SampleList:
    """Filter ground truth (GT) instances by score.

    Args:
        batch_data_samples (SampleList): The Data
            Samples. It usually includes information such as
            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
        score_thr (float): The score filter threshold.

    Returns:
        SampleList: The Data Samples filtered by score.
    """
    for data_samples in batch_data_samples:
        assert 'scores' in data_samples.gt_instances, \
            'there are no scores in instances'
        if data_samples.gt_instances.bboxes.shape[0] > 0:
            data_samples.gt_instances = data_samples.gt_instances[
                data_samples.gt_instances.scores > score_thr]
    return batch_data_samples


def _filter_gt_instances_by_size(batch_data_samples: SampleList,
                                 wh_thr: tuple) -> SampleList:
    """Filter ground truth (GT) instances by size.

    Args:
        batch_data_samples (SampleList): The Data
            Samples. It usually includes information such as
            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
        wh_thr (tuple): Minimum width and height of bbox.

    Returns:
        SampleList: The Data Samples filtered by size.
    """
    for data_samples in batch_data_samples:
        bboxes = data_samples.gt_instances.bboxes
        if bboxes.shape[0] > 0:
            w = bboxes[:, 2] - bboxes[:, 0]
            h = bboxes[:, 3] - bboxes[:, 1]
            data_samples.gt_instances = data_samples.gt_instances[
                (w > wh_thr[0]) & (h > wh_thr[1])]
    return batch_data_samples


def filter_gt_instances(batch_data_samples: SampleList,
                        score_thr: float = None,
                        wh_thr: tuple = None):
    """Filter ground truth (GT) instances by score and/or size.

    Args:
        batch_data_samples (SampleList): The Data
            Samples. It usually includes information such as
            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
        score_thr (float): The score filter threshold.
        wh_thr (tuple): Minimum width and height of bbox.

    Returns:
        SampleList: The Data Samples filtered by score and/or size.
    """
    if score_thr is not None:
        batch_data_samples = _filter_gt_instances_by_score(
            batch_data_samples, score_thr)
    if wh_thr is not None:
        batch_data_samples = _filter_gt_instances_by_size(
            batch_data_samples, wh_thr)
    return batch_data_samples


def rename_loss_dict(prefix: str, losses: dict) -> dict:
    """Rename the key names in loss dict by adding a prefix.

    Args:
        prefix (str): The prefix for loss components.
        losses (dict): A dictionary of loss components.

    Returns:
        dict: A dictionary of loss components with prefix.
    """
    return {prefix + k: v for k, v in losses.items()}


def reweight_loss_dict(losses: dict, weight: float) -> dict:
    """Reweight losses in the dict by weight.

    Args:
        losses (dict): A dictionary of loss components.
        weight (float): Weight for loss components.

    Returns:
        dict: A dictionary of weighted loss components.
    """
    for name, loss in losses.items():
        if 'loss' in name:
            if isinstance(loss, Sequence):
                losses[name] = [item * weight for item in loss]
            else:
                losses[name] = loss * weight
    return losses
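

# Usage sketch (illustrative, not part of the original module): typical
# semi-supervised bookkeeping, prefixing and down-weighting unsup losses.
def _demo_loss_dict_helpers():
    losses = {'loss_cls': torch.tensor(1.0), 'acc': torch.tensor(0.5)}
    weighted = reweight_loss_dict(losses, weight=0.5)
    assert weighted['loss_cls'].item() == 0.5
    assert weighted['acc'].item() == 0.5  # non-loss entries are untouched
    assert 'unsup_loss_cls' in rename_loss_dict('unsup_', weighted)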


def relative_coordinate_maps(
    locations: Tensor,
    centers: Tensor,
    strides: Tensor,
    size_of_interest: int,
    feat_sizes: Tuple[int],
) -> Tensor:
    """Generate the relative coordinate maps with feat_stride.

    Args:
        locations (Tensor): The prior location of mask feature map.
            It has shape (num_priors, 2).
        centers (Tensor): The prior points of an object in
            all feature pyramid levels. It has shape (num_pos, 2).
        strides (Tensor): The prior strides of an object in
            all feature pyramid levels. It has shape (num_pos, 1).
        size_of_interest (int): The size of the region used in rel coord.
        feat_sizes (Tuple[int]): The feature size H and W, which has 2 dims.

    Returns:
        rel_coord_feat (Tensor): The coordinate feature
        of shape (num_pos, 2, H, W).
    """
    H, W = feat_sizes
    rel_coordinates = centers.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
    rel_coordinates = rel_coordinates.permute(0, 2, 1).float()
    rel_coordinates = rel_coordinates / (
        strides[:, None, None] * size_of_interest)
    return rel_coordinates.reshape(-1, 2, H, W)
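

# Usage sketch (illustrative, not part of the original module): one positive
# center against a dense 4x6 grid of prior locations.
def _demo_relative_coordinate_maps():
    H, W = 4, 6
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W))
    locations = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1).float()
    centers = torch.tensor([[2.0, 1.0]])
    strides = torch.tensor([8.0])
    rel = relative_coordinate_maps(locations, centers, strides, 8, (H, W))
    assert rel.shape == (1, 2, 4, 6)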


def aligned_bilinear(tensor: Tensor, factor: int) -> Tensor:
    """Aligned bilinear upsampling, from the original implementation of
    CondInst:
    https://github.com/aim-uofa/AdelaiDet/blob/\
    c0b2092ce72442b0f40972f7c6dda8bb52c46d16/adet/utils/comm.py#L23
    """
    assert tensor.dim() == 4
    assert factor >= 1
    assert int(factor) == factor
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode='replicate')
    oh = factor * h + 1
    ow = factor * w + 1
    tensor = F.interpolate(
        tensor, size=(oh, ow), mode='bilinear', align_corners=True)
    tensor = F.pad(
        tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate')
    return tensor[:, :, :oh - 1, :ow - 1]
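

# Usage sketch (illustrative, not part of the original module): integer-
# factor upsampling scales H and W exactly by `factor`.
def _demo_aligned_bilinear():
    x = torch.rand(1, 8, 14, 14)
    assert aligned_bilinear(x, factor=4).shape == (1, 8, 56, 56)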


def unfold_wo_center(x, kernel_size: int, dilation: int) -> Tensor:
    """Unfold the feature map into its surrounding neighbours, excluding the
    center pixel, from the original implementation of BoxInst:
    https://github.com/aim-uofa/AdelaiDet/blob/\
    4a3a1f7372c35b48ebf5f6adc59f135a0fa28d60/\
    adet/modeling/condinst/condinst.py#L53
    """
    assert x.dim() == 4
    assert kernel_size % 2 == 1

    # using SAME padding
    padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2
    unfolded_x = F.unfold(
        x, kernel_size=kernel_size, padding=padding, dilation=dilation)
    unfolded_x = unfolded_x.reshape(
        x.size(0), x.size(1), -1, x.size(2), x.size(3))
    # remove the center pixels
    size = kernel_size**2
    unfolded_x = torch.cat(
        (unfolded_x[:, :, :size // 2], unfolded_x[:, :, size // 2 + 1:]),
        dim=2)
    return unfolded_x
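

# Usage sketch (illustrative, not part of the original module): a 3x3
# neighbourhood gives 8 shifted copies per pixel once the center is dropped.
def _demo_unfold_wo_center():
    x = torch.rand(2, 3, 10, 12)
    neighbours = unfold_wo_center(x, kernel_size=3, dilation=1)
    assert neighbours.shape == (2, 3, 8, 10, 12)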