# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
from torch import Tensor


def preprocess_panoptic_gt(gt_labels: Tensor, gt_masks: Tensor,
                           gt_semantic_seg: Optional[Tensor], num_things: int,
                           num_stuff: int) -> Tuple[Tensor, Tensor]:
    """Preprocess the ground truth for an image.

    Args:
        gt_labels (Tensor): Ground truth labels of each bbox,
            with shape (num_gts, ).
        gt_masks (BitmapMasks): Ground truth masks of each instance
            of an image, shape (num_gts, h, w). Only its
            ``to_tensor(dtype=..., device=...)`` method is used.
        gt_semantic_seg (Tensor | None): Ground truth of semantic
            segmentation with the shape (1, h, w).
            [0, num_thing_class - 1] means things,
            [num_thing_class, num_class-1] means stuff,
            255 means VOID. It's None when training instance segmentation.
        num_things (int): Number of thing classes.
        num_stuff (int): Number of stuff classes.

    Returns:
        tuple[Tensor, Tensor]: a tuple containing the following targets.

            - labels (Tensor): Ground truth class indices for an
                image, with shape (n, ), n is the sum of number
                of stuff type and number of instance in an image.
                Always has the same dtype/device as ``gt_labels``.
            - masks (Tensor): Ground truth mask for an image, with
                shape (n, h, w), dtype ``torch.long``. Contains stuff
                and things when training panoptic segmentation, and
                things only when training instance segmentation.
    """
    num_classes = num_things + num_stuff
    things_masks = gt_masks.to_tensor(
        dtype=torch.bool, device=gt_labels.device)

    # Instance segmentation: no semantic map, only things are returned.
    if gt_semantic_seg is None:
        return gt_labels, things_masks.long()

    gt_semantic_seg = gt_semantic_seg.squeeze(0)

    semantic_labels = torch.unique(
        gt_semantic_seg,
        sorted=False,
        return_inverse=False,
        return_counts=False)
    # Keep only stuff ids; thing ids and VOID / out-of-range values are
    # dropped. Boolean indexing preserves the order returned by unique.
    is_stuff = (semantic_labels >= num_things) & (
        semantic_labels < num_classes)
    stuff_labels = semantic_labels[is_stuff]

    if stuff_labels.numel() == 0:
        return gt_labels, things_masks.long()

    # One binary mask per stuff id via broadcasting: (k, 1, 1) vs (1, h, w).
    stuff_masks = gt_semantic_seg.unsqueeze(0) == stuff_labels.view(-1, 1, 1)

    # Cast/move so the concatenation does not depend on the semantic map's
    # dtype (e.g. float or uint8) or device; labels stay integer class ids.
    stuff_labels = stuff_labels.to(
        device=gt_labels.device, dtype=gt_labels.dtype)
    stuff_masks = stuff_masks.to(device=things_masks.device)

    labels = torch.cat([gt_labels, stuff_labels], dim=0)
    masks = torch.cat([things_masks, stuff_masks], dim=0)
    return labels, masks.long()