dist_utils.py

# Copyright (c) OpenMMLab. All rights reserved.
import functools
import pickle
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from mmengine.dist import get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    """Allreduce tensors in buckets to reduce the number of communication
    calls."""
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        # Group tensors by type so that each bucket can be flattened into a
        # single contiguous tensor.
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether to allreduce the gradients as a
            whole (coalesced into buckets). Defaults to True.
        bucket_size_mb (int, optional): Size of each bucket, in MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    world_size = dist.get_world_size()
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))
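
# Usage sketch (illustrative, not part of the original file): in a manually
# parallelised training step, gradients can be averaged across ranks right
# after the backward pass. `model`, `optimizer` and `data_batch` are
# hypothetical placeholders; torch.distributed must already be initialised.
#
#   loss = model(data_batch).sum()
#   loss.backward()
#   allreduce_grads(list(model.parameters()), coalesce=True, bucket_size_mb=32)
#   optimizer.step()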


def reduce_mean(tensor):
    """Obtain the mean of a tensor across all GPUs."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
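
# Usage sketch (illustrative, assumes an initialised process group): averaging
# a scalar metric such as a loss across all ranks so that logging is
# consistent. `loss` is a hypothetical placeholder tensor on the current GPU.
#
#   loss_for_logging = reduce_mean(loss.detach())
#   print(f'rank-averaged loss: {loss_for_logging.item():.4f}')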


def obj2tensor(pyobj, device='cuda'):
    """Serialize picklable python object to tensor."""
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)


def tensor2obj(tensor):
    """Deserialize tensor to picklable python object."""
    return pickle.loads(tensor.cpu().numpy().tobytes())
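
# Round-trip sketch (illustrative): obj2tensor/tensor2obj serialise an
# arbitrary picklable object through a byte tensor, which is how
# all_reduce_dict below broadcasts its key order. device='cpu' is used here
# only so the snippet runs without a GPU; that choice is an assumption, not
# part of the original usage.
#
#   keys = ['loss_cls', 'loss_bbox']
#   t = obj2tensor(keys, device='cpu')
#   assert tensor2obj(t) == keys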


@functools.lru_cache()
def _get_global_gloo_group():
    """Return a process group based on the gloo backend, containing all the
    ranks. The result is cached."""
    if dist.get_backend() == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD


def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
    """Apply an all-reduce operation to a python dict of tensors.

    The code is modified from
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.

    NOTE: make sure that ``py_dict`` has the same keys on every rank and that
    the corresponding values have the same shape. Currently only the NCCL
    backend is supported.

    Args:
        py_dict (dict): Dict to apply the all-reduce op to.
        op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'.
        group (:obj:`torch.distributed.group`, optional): Distributed group.
            Default: None.
        to_float (bool): Whether to convert all values of the dict to float.
            Default: True.

    Returns:
        OrderedDict: reduced python dict object.
    """
    warnings.warn(
        '`group` is deprecated. Currently only supports NCCL backend.')
    _, world_size = get_dist_info()
    if world_size == 1:
        return py_dict

    # all reduce logic across different devices.
    py_key = list(py_dict.keys())
    if not isinstance(py_dict, OrderedDict):
        # Broadcast the key order from rank 0 so that all ranks flatten the
        # dict in the same order.
        py_key_tensor = obj2tensor(py_key)
        dist.broadcast(py_key_tensor, src=0)
        py_key = tensor2obj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    if to_float:
        warnings.warn('Note: the "to_float" is True, you need to '
                      'ensure that the behavior is reasonable.')
        flatten_tensor = torch.cat(
            [py_dict[k].flatten().float() for k in py_key])
    else:
        flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM)
    if op == 'mean':
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape) for x, shape in zip(
            torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    out_dict = {k: v for k, v in zip(py_key, split_tensors)}
    if isinstance(py_dict, OrderedDict):
        out_dict = OrderedDict(out_dict)
    return out_dict
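
# Usage sketch (illustrative, assumes an initialised NCCL process group):
# reducing a dict of per-rank statistics, e.g. summed losses or sample
# counts. The dict contents below are hypothetical placeholders; every rank
# must pass the same keys with tensors of the same shape.
#
#   stats = OrderedDict(
#       num_samples=torch.tensor([128.], device='cuda'),
#       loss_sum=torch.tensor([3.21], device='cuda'))
#   stats = all_reduce_dict(stats, op='sum')
#   mean_loss = stats['loss_sum'] / stats['num_samples']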


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`, because the seed
    should be identical across all processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based on the same
    seed. Then different ranks could use different indices to select
    non-overlapped data from the same data list.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
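
# Usage sketch (illustrative, assumes an initialised process group): inside a
# custom sampler, every rank calls sync_random_seed() so that all of them
# shuffle indices with the same generator state. `dataset` is a hypothetical
# placeholder.
#
#   seed = sync_random_seed()
#   g = torch.Generator()
#   g.manual_seed(seed)
#   indices = torch.randperm(len(dataset), generator=g).tolist()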