PyTorch Distributed Training

PyTorch Distributed Communication Primitives

torch.distributed provides two kinds of communication primitives: (1) point-to-point (P2P) communication via send and recv, and (2) collective communication such as broadcast, scatter, gather, reduce, all-gather, and all-reduce.

P2P Communication

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank_id, size, fn, backend='gloo'):
  # Rendezvous address for the process group; all ranks must agree on it
  os.environ['MASTER_ADDR'] = '127.0.0.1'
  os.environ['MASTER_PORT'] = '29500'
  dist.init_process_group(backend, rank=rank_id, world_size=size)
  fn(rank_id, size)

def run(rank_id, size):
  tensor = torch.zeros(1)
  if rank_id == 0:
    # Rank 0 sends its tensor to rank 1, then waits for the reply
    tensor += 1
    dist.send(tensor=tensor, dst=1)
    dist.recv(tensor=tensor, src=1)
  else:
    # Rank 1 receives from rank 0, increments, and sends the result back
    dist.recv(tensor=tensor, src=0)
    tensor += 1
    dist.send(tensor=tensor, dst=0)

if __name__ == "__main__":
  size = 2
  processes = []
  mp.set_start_method("spawn")
  for rank in range(size):
    p = mp.Process(target=init_process, args=(rank, size, run))
    p.start()
    processes.append(p)
  for p in processes:
    p.join()

torch.multiprocessing is a wrapper around the native multiprocessing module.
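The same two processes can also be launched with torch.multiprocessing.spawn, which passes the process index (the rank) as the first argument to the target function. A minimal sketch, reusing init_process and run from above:

if __name__ == "__main__":
  size = 2
  # spawn() calls init_process(rank, size, run) for rank in range(size)
  mp.spawn(init_process, args=(size, run), nprocs=size, join=True)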

Collective Communication

Broadcast

def run(rank_id, size):
    # Rank i starts with the tensor [2*i + 1, 2*i + 2]
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.broadcast(tensor, src=0)

The tensor on the src rank is broadcast to all other ranks, overwriting their local tensor.
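For example, assuming a world size of 4 (as the later scatter and gather examples do), each rank starts from a different tensor, and after the call every rank holds rank 0's values:

# Before broadcast: rank 0: [1, 2]  rank 1: [3, 4]  rank 2: [5, 6]  rank 3: [7, 8]
# After dist.broadcast(tensor, src=0): every rank holds [1, 2]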

Scatter

def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        # Only the src rank provides scatter_list, one tensor per destination rank
        scatter_list = [torch.tensor([0, 0]), torch.tensor([1, 1]), torch.tensor([2, 2]), torch.tensor([3, 3])]
        dist.scatter(tensor, scatter_list=scatter_list, src=0)
    else:
        # Non-src ranks only provide the output tensor
        dist.scatter(tensor, src=0)

Note that the destination ranks receive the scattered tensors in sequential order; that is, rank i receives scatter_list[i].
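With the scatter_list above and a world size of 4, each rank ends up with one slice:

# After dist.scatter(..., src=0):
# rank 0: [0, 0]  rank 1: [1, 1]  rank 2: [2, 2]  rank 3: [3, 3]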

Gather

def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        # Only the dst rank provides gather_list, one output slot per rank
        gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(4)]
        dist.gather(tensor, gather_list=gather_list, dst=0)
    else:
        # Non-dst ranks only send their tensor
        dist.gather(tensor, dst=0)
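Assuming a world size of 4, only rank 0 receives the gathered tensors, ordered by rank:

# gather_list on rank 0 after dist.gather(..., dst=0):
# [tensor([1, 2]), tensor([3, 4]), tensor([5, 6]), tensor([7, 8])]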

Reduce

def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # Elementwise sum across all ranks; the result lands on rank 3
    dist.reduce(tensor, dst=3, op=dist.ReduceOp.SUM)

PyTorch supports a variety of reduction operations, including SUM, PRODUCT, MIN, MAX, BAND, BOR, and BXOR.
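With four ranks, the SUM example above leaves [16, 20] on rank 3 (the elementwise sum of [1, 2], [3, 4], [5, 6], and [7, 8]). Switching operations only changes the op argument; a minimal sketch with MAX (the function name run_reduce_max is just illustrative):

def run_reduce_max(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # Elementwise maximum across ranks lands on rank 3: [7, 8] with four ranks
    dist.reduce(tensor, dst=3, op=dist.ReduceOp.MAX)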

It's important to note that only the tensor on the dst rank is guaranteed to hold the final result; as an implementation detail, the tensors on other ranks may also be modified in place by intermediate reduction steps.

All-gather and all-reduce

def run_all_gather(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # Every rank provides an output list and ends up with all ranks' tensors
    gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(4)]
    dist.all_gather(gather_list, tensor)

def run_all_reduce(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # Every rank ends up with the reduced result, not just a single destination
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
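Assuming a world size of 4, both collectives leave the same result on every rank:

# all_gather: every rank's gather_list == [tensor([1, 2]), tensor([3, 4]), tensor([5, 6]), tensor([7, 8])]
# all_reduce with SUM: every rank's tensor == tensor([16, 20])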
