# anchor_generator.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. from torch import Tensor
  5. from mmdet.models.task_modules.prior_generators.anchor_generator import \
  6. AnchorGenerator
  7. from mmdet.registry import TASK_UTILS
  8. from mmdet.structures.bbox import HorizontalBoxes
  9. DeviceType = Union[str, torch.device]
  10. @TASK_UTILS.register_module()
  11. class YXYXAnchorGenerator(AnchorGenerator):
  12. def gen_single_level_base_anchors(self,
  13. base_size: Union[int, float],
  14. scales: Tensor,
  15. ratios: Tensor,
  16. center: Optional[Tuple[float]] = None) \
  17. -> Tensor:
  18. """Generate base anchors of a single level.
  19. Args:
  20. base_size (int | float): Basic size of an anchor.
  21. scales (torch.Tensor): Scales of the anchor.
  22. ratios (torch.Tensor): The ratio between the height
  23. and width of anchors in a single level.
  24. center (tuple[float], optional): The center of the base anchor
  25. related to a single feature grid. Defaults to None.
  26. Returns:
  27. torch.Tensor: Anchors in a single-level feature maps.
  28. """
  29. w = base_size
  30. h = base_size
  31. if center is None:
  32. x_center = self.center_offset * w
  33. y_center = self.center_offset * h
  34. else:
  35. x_center, y_center = center
  36. h_ratios = torch.sqrt(ratios)
  37. w_ratios = 1 / h_ratios
  38. if self.scale_major:
  39. ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
  40. hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
  41. else:
  42. ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
  43. hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
  44. # use float anchor and the anchor's center is aligned with the
  45. # pixel center
  46. base_anchors = [
  47. y_center - 0.5 * hs,
  48. x_center - 0.5 * ws,
  49. y_center + 0.5 * hs,
  50. x_center + 0.5 * ws,
  51. ]
  52. base_anchors = torch.stack(base_anchors, dim=-1)
  53. return base_anchors
  54. def single_level_grid_priors(self,
  55. featmap_size: Tuple[int, int],
  56. level_idx: int,
  57. dtype: torch.dtype = torch.float32,
  58. device: DeviceType = 'cuda') -> Tensor:
  59. """Generate grid anchors of a single level.
  60. Note:
  61. This function is usually called by method ``self.grid_priors``.
  62. Args:
  63. featmap_size (tuple[int, int]): Size of the feature maps.
  64. level_idx (int): The index of corresponding feature map level.
  65. dtype (obj:`torch.dtype`): Date type of points.Defaults to
  66. ``torch.float32``.
  67. device (str | torch.device): The device the tensor will be put on.
  68. Defaults to 'cuda'.
  69. Returns:
  70. torch.Tensor: Anchors in the overall feature maps.
  71. """
  72. base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
  73. feat_h, feat_w = featmap_size
  74. stride_w, stride_h = self.strides[level_idx]
  75. # First create Range with the default dtype, than convert to
  76. # target `dtype` for onnx exporting.
  77. shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
  78. shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
  79. shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
  80. shifts = torch.stack([shift_yy, shift_xx, shift_yy, shift_xx], dim=-1)
  81. # first feat_w elements correspond to the first row of shifts
  82. # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
  83. # shifted anchors (K, A, 4), reshape to (K*A, 4)
  84. all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
  85. all_anchors = all_anchors.view(-1, 4)
  86. # first A rows correspond to A anchors of (0, 0) in feature map,
  87. # then (0, 1), (0, 2), ...
  88. if self.use_box_type:
  89. all_anchors = HorizontalBoxes(all_anchors)
  90. return all_anchors