123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple, Union
- import mmcv
- import numpy as np
- from mmengine.utils import is_str
def palette_val(palette: List[tuple]) -> List[tuple]:
    """Rescale a 0-255 palette into matplotlib's 0-1 color range.

    Args:
        palette (List[tuple]): Color tuples whose channels are expressed
            on the 0-255 scale.

    Returns:
        List[tuple[float]]: One tuple per input color, in the same order,
        with every channel divided by 255.
    """
    # True division is used, so the channels always come back as floats.
    return [tuple(channel / 255 for channel in rgb) for rgb in palette]
def get_palette(palette: Union[List[tuple], str, tuple],
                num_classes: int) -> List[Tuple[int]]:
    """Get palette from various inputs.

    Args:
        palette (list[tuple] | str | tuple | None): palette inputs.

            - list: used as-is, one color per class.
            - tuple: a single color repeated for every class.
            - ``'random'`` or ``None``: reproducible pseudo-random colors
              (fixed seed 42).
            - ``'coco'``, ``'citys'``, ``'voc'``: the ``METAINFO['palette']``
              of the corresponding mmdet dataset class.
            - any other str: a color name resolved by ``mmcv.color_val``
              and repeated for every class.
        num_classes (int): the number of classes.

    Returns:
        list[tuple[int]]: A list of color tuples containing at least
        ``num_classes`` entries.

    Raises:
        TypeError: If ``palette`` is none of the supported kinds.
        AssertionError: If ``num_classes`` is not an ``int`` or the
            resolved palette has fewer than ``num_classes`` colors.
    """
    assert isinstance(num_classes, int)

    if isinstance(palette, list):
        dataset_palette = palette
    elif isinstance(palette, tuple):
        dataset_palette = [palette] * num_classes
    elif palette == 'random' or palette is None:
        # Draw from a private generator seeded with 42. It produces the same
        # colors as seeding the global generator would, but never touches
        # the global RNG state, so nothing has to be saved and restored and
        # an exception cannot leave the global generator reseeded.
        rng = np.random.RandomState(42)
        colors = rng.randint(0, 256, size=(num_classes, 3))
        # ``tolist`` yields built-in ints, matching the declared return type
        # instead of leaking NumPy integer scalars to callers.
        dataset_palette = [tuple(color) for color in colors.tolist()]
    elif palette == 'coco':
        from mmdet.datasets import CocoDataset, CocoPanopticDataset
        dataset_palette = CocoDataset.METAINFO['palette']
        # Fall back to the panoptic palette when the detection palette
        # does not have enough colors for the requested class count.
        if len(dataset_palette) < num_classes:
            dataset_palette = CocoPanopticDataset.METAINFO['palette']
    elif palette == 'citys':
        from mmdet.datasets import CityscapesDataset
        dataset_palette = CityscapesDataset.METAINFO['palette']
    elif palette == 'voc':
        from mmdet.datasets import VOCDataset
        dataset_palette = VOCDataset.METAINFO['palette']
    elif is_str(palette):
        # The channel order returned by ``mmcv.color_val`` is reversed here
        # (presumably BGR -> RGB; confirm against the mmcv documentation).
        dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
    else:
        raise TypeError(f'Invalid type for palette: {type(palette)}')

    assert len(dataset_palette) >= num_classes, \
        'The length of palette should not be less than `num_classes`.'
    return dataset_palette
- def _get_adaptive_scales(areas: np.ndarray,
- min_area: int = 800,
- max_area: int = 30000) -> np.ndarray:
- """Get adaptive scales according to areas.
- The scale range is [0.5, 1.0]. When the area is less than
- ``min_area``, the scale is 0.5 while the area is larger than
- ``max_area``, the scale is 1.0.
- Args:
- areas (ndarray): The areas of bboxes or masks with the
- shape of (n, ).
- min_area (int): Lower bound areas for adaptive scales.
- Defaults to 800.
- max_area (int): Upper bound areas for adaptive scales.
- Defaults to 30000.
- Returns:
- ndarray: The adaotive scales with the shape of (n, ).
- """
- scales = 0.5 + (areas - min_area) / (max_area - min_area)
- scales = np.clip(scales, 0.5, 1.0)
- return scales
def jitter_color(color: tuple) -> tuple:
    """Randomly perturb a color so that instances sharing a class can be
    told apart.

    Args:
        color (tuple): The RGB color tuple. Each value is between [0, 255].

    Returns:
        tuple: The jittered color tuple; its three channels are clipped to
        [0, 255] and cast to ``np.uint8``.
    """
    # Random non-negative direction, normalized to unit length.
    direction = np.random.rand(3)
    unit = direction / np.linalg.norm(direction)
    # Re-center around zero and scale to a fraction of the 0-255 range.
    offset = (unit - 0.5) * 0.5 * 255
    shifted = np.clip(offset + color, 0, 255)
    return tuple(shifted.astype(np.uint8))
|