utils.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple
  3. import torch
  4. from torch import Tensor
  5. from mmdet.structures.bbox import BaseBoxes
  6. def anchor_inside_flags(flat_anchors: Tensor,
  7. valid_flags: Tensor,
  8. img_shape: Tuple[int],
  9. allowed_border: int = 0) -> Tensor:
  10. """Check whether the anchors are inside the border.
  11. Args:
  12. flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
  13. valid_flags (torch.Tensor): An existing valid flags of anchors.
  14. img_shape (tuple(int)): Shape of current image.
  15. allowed_border (int): The border to allow the valid anchor.
  16. Defaults to 0.
  17. Returns:
  18. torch.Tensor: Flags indicating whether the anchors are inside a \
  19. valid range.
  20. """
  21. img_h, img_w = img_shape[:2]
  22. if allowed_border >= 0:
  23. if isinstance(flat_anchors, BaseBoxes):
  24. inside_flags = valid_flags & \
  25. flat_anchors.is_inside([img_h, img_w],
  26. all_inside=True,
  27. allowed_border=allowed_border)
  28. else:
  29. inside_flags = valid_flags & \
  30. (flat_anchors[:, 0] >= -allowed_border) & \
  31. (flat_anchors[:, 1] >= -allowed_border) & \
  32. (flat_anchors[:, 2] < img_w + allowed_border) & \
  33. (flat_anchors[:, 3] < img_h + allowed_border)
  34. else:
  35. inside_flags = valid_flags
  36. return inside_flags
  37. def calc_region(bbox: Tensor,
  38. ratio: float,
  39. featmap_size: Optional[Tuple] = None) -> Tuple[int]:
  40. """Calculate a proportional bbox region.
  41. The bbox center are fixed and the new h' and w' is h * ratio and w * ratio.
  42. Args:
  43. bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
  44. ratio (float): Ratio of the output region.
  45. featmap_size (tuple, Optional): Feature map size in (height, width)
  46. order used for clipping the boundary. Defaults to None.
  47. Returns:
  48. tuple: x1, y1, x2, y2
  49. """
  50. x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
  51. y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
  52. x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
  53. y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
  54. if featmap_size is not None:
  55. x1 = x1.clamp(min=0, max=featmap_size[1])
  56. y1 = y1.clamp(min=0, max=featmap_size[0])
  57. x2 = x2.clamp(min=0, max=featmap_size[1])
  58. y2 = y2.clamp(min=0, max=featmap_size[0])
  59. return (x1, y1, x2, y2)