palette.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple, Union
  3. import mmcv
  4. import numpy as np
  5. from mmengine.utils import is_str
  6. def palette_val(palette: List[tuple]) -> List[tuple]:
  7. """Convert palette to matplotlib palette.
  8. Args:
  9. palette (List[tuple]): A list of color tuples.
  10. Returns:
  11. List[tuple[float]]: A list of RGB matplotlib color tuples.
  12. """
  13. new_palette = []
  14. for color in palette:
  15. color = [c / 255 for c in color]
  16. new_palette.append(tuple(color))
  17. return new_palette
  18. def get_palette(palette: Union[List[tuple], str, tuple],
  19. num_classes: int) -> List[Tuple[int]]:
  20. """Get palette from various inputs.
  21. Args:
  22. palette (list[tuple] | str | tuple): palette inputs.
  23. num_classes (int): the number of classes.
  24. Returns:
  25. list[tuple[int]]: A list of color tuples.
  26. """
  27. assert isinstance(num_classes, int)
  28. if isinstance(palette, list):
  29. dataset_palette = palette
  30. elif isinstance(palette, tuple):
  31. dataset_palette = [palette] * num_classes
  32. elif palette == 'random' or palette is None:
  33. state = np.random.get_state()
  34. # random color
  35. np.random.seed(42)
  36. palette = np.random.randint(0, 256, size=(num_classes, 3))
  37. np.random.set_state(state)
  38. dataset_palette = [tuple(c) for c in palette]
  39. elif palette == 'coco':
  40. from mmdet.datasets import CocoDataset, CocoPanopticDataset
  41. dataset_palette = CocoDataset.METAINFO['palette']
  42. if len(dataset_palette) < num_classes:
  43. dataset_palette = CocoPanopticDataset.METAINFO['palette']
  44. elif palette == 'citys':
  45. from mmdet.datasets import CityscapesDataset
  46. dataset_palette = CityscapesDataset.METAINFO['palette']
  47. elif palette == 'voc':
  48. from mmdet.datasets import VOCDataset
  49. dataset_palette = VOCDataset.METAINFO['palette']
  50. elif is_str(palette):
  51. dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
  52. else:
  53. raise TypeError(f'Invalid type for palette: {type(palette)}')
  54. assert len(dataset_palette) >= num_classes, \
  55. 'The length of palette should not be less than `num_classes`.'
  56. return dataset_palette
  57. def _get_adaptive_scales(areas: np.ndarray,
  58. min_area: int = 800,
  59. max_area: int = 30000) -> np.ndarray:
  60. """Get adaptive scales according to areas.
  61. The scale range is [0.5, 1.0]. When the area is less than
  62. ``min_area``, the scale is 0.5 while the area is larger than
  63. ``max_area``, the scale is 1.0.
  64. Args:
  65. areas (ndarray): The areas of bboxes or masks with the
  66. shape of (n, ).
  67. min_area (int): Lower bound areas for adaptive scales.
  68. Defaults to 800.
  69. max_area (int): Upper bound areas for adaptive scales.
  70. Defaults to 30000.
  71. Returns:
  72. ndarray: The adaotive scales with the shape of (n, ).
  73. """
  74. scales = 0.5 + (areas - min_area) / (max_area - min_area)
  75. scales = np.clip(scales, 0.5, 1.0)
  76. return scales
  77. def jitter_color(color: tuple) -> tuple:
  78. """Randomly jitter the given color in order to better distinguish instances
  79. with the same class.
  80. Args:
  81. color (tuple): The RGB color tuple. Each value is between [0, 255].
  82. Returns:
  83. tuple: The jittered color tuple.
  84. """
  85. jitter = np.random.rand(3)
  86. jitter = (jitter / np.linalg.norm(jitter) - 0.5) * 0.5 * 255
  87. color = np.clip(jitter + color, 0, 255).astype(np.uint8)
  88. return tuple(color)