instaboost.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import numpy as np
  4. from mmcv.transforms import BaseTransform
  5. from mmdet.registry import TRANSFORMS
  6. @TRANSFORMS.register_module()
  7. class InstaBoost(BaseTransform):
  8. r"""Data augmentation method in `InstaBoost: Boosting Instance
  9. Segmentation Via Probability Map Guided Copy-Pasting
  10. <https://arxiv.org/abs/1908.07801>`_.
  11. Refer to https://github.com/GothicAi/Instaboost for implementation details.
  12. Required Keys:
  13. - img (np.uint8)
  14. - instances
  15. Modified Keys:
  16. - img (np.uint8)
  17. - instances
  18. Args:
  19. action_candidate (tuple): Action candidates. "normal", "horizontal", \
  20. "vertical", "skip" are supported. Defaults to ('normal', \
  21. 'horizontal', 'skip').
  22. action_prob (tuple): Corresponding action probabilities. Should be \
  23. the same length as action_candidate. Defaults to (1, 0, 0).
  24. scale (tuple): (min scale, max scale). Defaults to (0.8, 1.2).
  25. dx (int): The maximum x-axis shift will be (instance width) / dx.
  26. Defaults to 15.
  27. dy (int): The maximum y-axis shift will be (instance height) / dy.
  28. Defaults to 15.
  29. theta (tuple): (min rotation degree, max rotation degree). \
  30. Defaults to (-1, 1).
  31. color_prob (float): Probability of images for color augmentation.
  32. Defaults to 0.5.
  33. hflag (bool): Whether to use heatmap guided. Defaults to False.
  34. aug_ratio (float): Probability of applying this transformation. \
  35. Defaults to 0.5.
  36. """
  37. def __init__(self,
  38. action_candidate: tuple = ('normal', 'horizontal', 'skip'),
  39. action_prob: tuple = (1, 0, 0),
  40. scale: tuple = (0.8, 1.2),
  41. dx: int = 15,
  42. dy: int = 15,
  43. theta: tuple = (-1, 1),
  44. color_prob: float = 0.5,
  45. hflag: bool = False,
  46. aug_ratio: float = 0.5) -> None:
  47. import matplotlib
  48. import matplotlib.pyplot as plt
  49. default_backend = plt.get_backend()
  50. try:
  51. import instaboostfast as instaboost
  52. except ImportError:
  53. raise ImportError(
  54. 'Please run "pip install instaboostfast" '
  55. 'to install instaboostfast first for instaboost augmentation.')
  56. # instaboost will modify the default backend
  57. # and cause visualization to fail.
  58. matplotlib.use(default_backend)
  59. self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
  60. scale, dx, dy, theta,
  61. color_prob, hflag)
  62. self.aug_ratio = aug_ratio
  63. def _load_anns(self, results: dict) -> Tuple[list, list]:
  64. """Convert raw anns to instaboost expected input format."""
  65. anns = []
  66. ignore_anns = []
  67. for instance in results['instances']:
  68. label = instance['bbox_label']
  69. bbox = instance['bbox']
  70. mask = instance['mask']
  71. x1, y1, x2, y2 = bbox
  72. # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
  73. bbox = [x1, y1, x2 - x1, y2 - y1]
  74. if instance['ignore_flag'] == 0:
  75. anns.append({
  76. 'category_id': label,
  77. 'segmentation': mask,
  78. 'bbox': bbox
  79. })
  80. else:
  81. # Ignore instances without data augmentation
  82. ignore_anns.append(instance)
  83. return anns, ignore_anns
  84. def _parse_anns(self, results: dict, anns: list, ignore_anns: list,
  85. img: np.ndarray) -> dict:
  86. """Restore the result of instaboost processing to the original anns
  87. format."""
  88. instances = []
  89. for ann in anns:
  90. x1, y1, w, h = ann['bbox']
  91. # TODO: more essential bug need to be fixed in instaboost
  92. if w <= 0 or h <= 0:
  93. continue
  94. bbox = [x1, y1, x1 + w, y1 + h]
  95. instances.append(
  96. dict(
  97. bbox=bbox,
  98. bbox_label=ann['category_id'],
  99. mask=ann['segmentation'],
  100. ignore_flag=0))
  101. instances.extend(ignore_anns)
  102. results['img'] = img
  103. results['instances'] = instances
  104. return results
  105. def transform(self, results) -> dict:
  106. """The transform function."""
  107. img = results['img']
  108. ori_type = img.dtype
  109. if 'instances' not in results or len(results['instances']) == 0:
  110. return results
  111. anns, ignore_anns = self._load_anns(results)
  112. if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
  113. try:
  114. import instaboostfast as instaboost
  115. except ImportError:
  116. raise ImportError('Please run "pip install instaboostfast" '
  117. 'to install instaboostfast first.')
  118. anns, img = instaboost.get_new_data(
  119. anns, img.astype(np.uint8), self.cfg, background=None)
  120. results = self._parse_anns(results, anns, ignore_anns,
  121. img.astype(ori_type))
  122. return results
  123. def __repr__(self) -> str:
  124. repr_str = self.__class__.__name__
  125. repr_str += f'(aug_ratio={self.aug_ratio})'
  126. return repr_str