import os

import torch
import torch.distributed as dist


def suppress_output(rank):
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force:
            builtin_print("rank #%d:" % rank, *args, **kwargs)
        elif rank == 0:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed() -> torch.device:
    """Initialize the model for distributed inference."""
    # Initialize distributed inference
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    if world_size > 1:
        dist.init_process_group(
            backend="nccl", init_method="env://", world_size=world_size, rank=rank
        )
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # Warm up NCCL to avoid first-time latency
    if world_size > 1:
        x = torch.ones(1, device=device)
        dist.all_reduce(x)
        torch.cuda.synchronize(device)

    suppress_output(rank)
    return device
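

# Minimal usage sketch (illustrative, under the assumption that the script is
# launched with `torchrun --nproc_per_node=<num_gpus> script.py`, which sets the
# RANK and WORLD_SIZE environment variables read by init_distributed()).
if __name__ == "__main__":
    device = init_distributed()
    # After init_distributed(), only rank 0 prints by default; pass force=True
    # to print from every rank (see suppress_output above).
    print(f"initialized on {device}", force=True)
    if dist.is_initialized():
        dist.destroy_process_group()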