PyTorch Distributed Communication Primitives
torch.distributed provides two kinds of communication primitives: (1) point-to-point (P2P) communication, via send and recv, and (2) collective communication.
P2P Communication
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank_id, size, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank_id, world_size=size)
    fn(rank_id, size)

def run(rank_id, size):
    tensor = torch.zeros(1)
    if rank_id == 0:
        # Rank 0 sends the tensor to rank 1, then waits for the reply.
        tensor += 1
        dist.send(tensor=tensor, dst=1)
        dist.recv(tensor=tensor, src=1)   # rank 0 ends with tensor([2.])
    else:
        # Rank 1 receives the tensor, increments it, and sends it back.
        dist.recv(tensor=tensor, src=0)
        tensor += 1
        dist.send(tensor=tensor, dst=0)

if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
torch.multiprocessing is a wrapper around Python's native multiprocessing module.
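torch.multiprocessing also provides mp.spawn, which can replace the hand-written process loop above. A minimal sketch, reusing the same init_process and run functions:

if __name__ == "__main__":
    size = 2
    # mp.spawn starts `nprocs` processes and passes the process index
    # (the rank) as the first argument to the target function.
    mp.spawn(init_process, args=(size, run), nprocs=size, join=True)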
Collective Communication
The collective examples below assume four processes (world_size = 4); they can be launched with the same driver as above, with size = 4.
Broadcast
# Broadcast
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.broadcast(tensor, src=0)
The tensor on the src rank is copied to every other rank, overwriting their local tensor: with the four ranks above, the initial tensors are [1, 2], [3, 4], [5, 6], and [7, 8], and after the call every rank holds [1, 2].
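A common use in distributed training is broadcasting the initial model parameters from rank 0 so that every worker starts from identical weights. A minimal sketch with a hypothetical broadcast_parameters helper, assuming the model has already been constructed on every rank:

def broadcast_parameters(model):
    # Hypothetical helper: copy rank 0's parameter values to all other ranks
    # so every replica begins training from the same weights.
    for param in model.parameters():
        dist.broadcast(param.data, src=0)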
Scatter
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        scatter_list = [torch.tensor([0, 0]), torch.tensor([1, 1]),
                        torch.tensor([2, 2]), torch.tensor([3, 3])]
        dist.scatter(tensor, scatter_list=scatter_list, src=0)
    else:
        dist.scatter(tensor, src=0)
Note that the destination ranks receive the scattered tensors in rank order: rank i receives scatter_list[i], so here rank 0 ends with [0, 0], rank 1 with [1, 1], and so on. Only the src rank supplies scatter_list.
Gather
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        # After the call, gather_list on rank 0 holds [1,2], [3,4], [5,6], [7,8].
        gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(4)]
        dist.gather(tensor, gather_list=gather_list, dst=0)
    else:
        dist.gather(tensor, dst=0)
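Only the dst rank supplies gather_list; it receives one tensor per rank, again ordered by rank. A common use is collecting per-rank scalars, such as a local loss, onto rank 0 for logging. A minimal sketch with a hypothetical gather_losses helper:

def gather_losses(local_loss, rank_id, size):
    # Hypothetical helper: gather each rank's scalar loss onto rank 0.
    loss_tensor = torch.tensor([local_loss], dtype=torch.float32)
    if rank_id == 0:
        gathered = [torch.zeros(1, dtype=torch.float32) for _ in range(size)]
        dist.gather(loss_tensor, gather_list=gathered, dst=0)
        return [t.item() for t in gathered]
    dist.gather(loss_tensor, dst=0)
    return None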
Reduce
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.reduce(tensor, dst=3, op=dist.ReduceOp.SUM)
PyTorch supports a variety of reduction operations, including SUM, PRODUCT, MIN, MAX, BAND, BOR, and BXOR.
Note that only the tensor on the dst rank is guaranteed to hold the final result; because the reduction may be performed in place, the tensors on the other ranks can also be modified along the way and should not be relied upon.
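As an illustration of both points, a minimal sketch that uses a different reduction op and only reads the result on dst:

def run_reduce_max(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # Element-wise maximum across ranks instead of a sum.
    dist.reduce(tensor, dst=3, op=dist.ReduceOp.MAX)
    if rank_id == 3:
        # Only rank 3 is guaranteed to hold the result: the element-wise
        # maximum of [1,2], [3,4], [5,6], [7,8] is [7, 8].
        print(tensor)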
All-gather and all-reduce
def run_all_gather(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(4)]
    dist.all_gather(gather_list, tensor)
    # Every rank ends up with gather_list == [[1,2], [3,4], [5,6], [7,8]].

def run_all_reduce(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Every rank ends up with tensor == [16, 20].
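In data-parallel training, all_reduce is the primitive behind gradient averaging: after each backward pass, every worker sums the gradients across ranks and divides by the world size, so all replicas apply the same update. A minimal sketch of the idea expressed with the primitives above (not PyTorch's DistributedDataParallel):

def average_gradients(model, size):
    # Sum each gradient across all ranks, then divide by the world size,
    # so every replica steps with the same averaged gradient.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size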