test_palette.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. from mmdet.datasets import CocoDataset
  4. from mmdet.visualization import get_palette, jitter_color, palette_val
  5. def test_palette():
  6. assert palette_val([(1, 2, 3)])[0] == (1 / 255, 2 / 255, 3 / 255)
  7. # test list
  8. palette = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
  9. palette_ = get_palette(palette, 3)
  10. for color, color_ in zip(palette, palette_):
  11. assert color == color_
  12. # test tuple
  13. palette = get_palette((1, 2, 3), 3)
  14. assert len(palette) == 3
  15. for color in palette:
  16. assert color == (1, 2, 3)
  17. # test color str
  18. palette = get_palette('red', 3)
  19. assert len(palette) == 3
  20. for color in palette:
  21. assert color == (255, 0, 0)
  22. # test dataset str
  23. palette = get_palette('coco', len(CocoDataset.METAINFO['classes']))
  24. assert len(palette) == len(CocoDataset.METAINFO['classes'])
  25. assert palette[0] == (220, 20, 60)
  26. # TODO: Awaiting refactoring
  27. # palette = get_palette('coco', len(CocoPanopticDataset.METAINFO['CLASSES'])) # noqa
  28. # assert len(palette) == len(CocoPanopticDataset.METAINFO['CLASSES'])
  29. # assert palette[-1] == (250, 141, 255)
  30. # palette = get_palette('voc', len(VOCDataset.METAINFO['CLASSES']))
  31. # assert len(palette) == len(VOCDataset.METAINFO['CLASSES'])
  32. # assert palette[0] == (106, 0, 228)
  33. # palette = get_palette('citys', len(CityscapesDataset.METAINFO['CLASSES'])) # noqa
  34. # assert len(palette) == len(CityscapesDataset.METAINFO['CLASSES'])
  35. # assert palette[0] == (220, 20, 60)
  36. # test random
  37. palette1 = get_palette('random', 3)
  38. palette2 = get_palette(None, 3)
  39. for color1, color2 in zip(palette1, palette2):
  40. assert isinstance(color1, tuple)
  41. assert isinstance(color2, tuple)
  42. assert color1 == color2
  43. def test_jitter_color():
  44. color = tuple(np.random.randint(0, 255, 3, np.uint8))
  45. jittered_color = jitter_color(color)
  46. for c in jittered_color:
  47. assert 0 <= c <= 255