panoptic_gt_processing.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch
  4. from torch import Tensor
  5. def preprocess_panoptic_gt(gt_labels: Tensor, gt_masks: Tensor,
  6. gt_semantic_seg: Tensor, num_things: int,
  7. num_stuff: int) -> Tuple[Tensor, Tensor]:
  8. """Preprocess the ground truth for a image.
  9. Args:
  10. gt_labels (Tensor): Ground truth labels of each bbox,
  11. with shape (num_gts, ).
  12. gt_masks (BitmapMasks): Ground truth masks of each instances
  13. of a image, shape (num_gts, h, w).
  14. gt_semantic_seg (Tensor | None): Ground truth of semantic
  15. segmentation with the shape (1, h, w).
  16. [0, num_thing_class - 1] means things,
  17. [num_thing_class, num_class-1] means stuff,
  18. 255 means VOID. It's None when training instance segmentation.
  19. Returns:
  20. tuple[Tensor, Tensor]: a tuple containing the following targets.
  21. - labels (Tensor): Ground truth class indices for a
  22. image, with shape (n, ), n is the sum of number
  23. of stuff type and number of instance in a image.
  24. - masks (Tensor): Ground truth mask for a image, with
  25. shape (n, h, w). Contains stuff and things when training
  26. panoptic segmentation, and things only when training
  27. instance segmentation.
  28. """
  29. num_classes = num_things + num_stuff
  30. things_masks = gt_masks.to_tensor(
  31. dtype=torch.bool, device=gt_labels.device)
  32. if gt_semantic_seg is None:
  33. masks = things_masks.long()
  34. return gt_labels, masks
  35. things_labels = gt_labels
  36. gt_semantic_seg = gt_semantic_seg.squeeze(0)
  37. semantic_labels = torch.unique(
  38. gt_semantic_seg,
  39. sorted=False,
  40. return_inverse=False,
  41. return_counts=False)
  42. stuff_masks_list = []
  43. stuff_labels_list = []
  44. for label in semantic_labels:
  45. if label < num_things or label >= num_classes:
  46. continue
  47. stuff_mask = gt_semantic_seg == label
  48. stuff_masks_list.append(stuff_mask)
  49. stuff_labels_list.append(label)
  50. if len(stuff_masks_list) > 0:
  51. stuff_masks = torch.stack(stuff_masks_list, dim=0)
  52. stuff_labels = torch.stack(stuff_labels_list, dim=0)
  53. labels = torch.cat([things_labels, stuff_labels], dim=0)
  54. masks = torch.cat([things_masks, stuff_masks], dim=0)
  55. else:
  56. labels = things_labels
  57. masks = things_masks
  58. masks = masks.long()
  59. return labels, masks