split_batch.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. def split_batch(img, img_metas, kwargs):
  4. """Split data_batch by tags.
  5. Code is modified from
  6. <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/structure_utils.py> # noqa: E501
  7. Args:
  8. img (Tensor): of shape (N, C, H, W) encoding input images.
  9. Typically these should be mean centered and std scaled.
  10. img_metas (list[dict]): List of image info dict where each dict
  11. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  12. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  13. For details on the values of these keys, see
  14. :class:`mmdet.datasets.pipelines.Collect`.
  15. kwargs (dict): Specific to concrete implementation.
  16. Returns:
  17. data_groups (dict): a dict that data_batch splited by tags,
  18. such as 'sup', 'unsup_teacher', and 'unsup_student'.
  19. """
  20. # only stack img in the batch
  21. def fuse_list(obj_list, obj):
  22. return torch.stack(obj_list) if isinstance(obj,
  23. torch.Tensor) else obj_list
  24. # select data with tag from data_batch
  25. def select_group(data_batch, current_tag):
  26. group_flag = [tag == current_tag for tag in data_batch['tag']]
  27. return {
  28. k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
  29. for k, v in data_batch.items()
  30. }
  31. kwargs.update({'img': img, 'img_metas': img_metas})
  32. kwargs.update({'tag': [meta['tag'] for meta in img_metas]})
  33. tags = list(set(kwargs['tag']))
  34. data_groups = {tag: select_group(kwargs, tag) for tag in tags}
  35. for tag, group in data_groups.items():
  36. group.pop('tag')
  37. return data_groups