PyTorch Distributed Training

PyTorch Distributed Communication Primitives

torch.distributed provides two classes of communication primitives: (1) point-to-point (P2P) communication, via send and recv, and (2) collective communication.

P2P Communication

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank_id, size, fn, backend='gloo'):
  # set the rendezvous address, join the process group, then run the worker fn
  os.environ['MASTER_ADDR'] = '127.0.0.1'
  os.environ['MASTER_PORT'] = '29500'
  dist.init_process_group(backend, rank=rank_id, world_size=size)
  fn(rank_id, size)
  
def run(rank_id, size):
  tensor = torch.zeros(1)
  if rank_id == 0:
    # rank 0 increments the tensor, sends it to rank 1, then waits for the reply
    tensor += 1
    dist.send(tensor=tensor, dst=1)
    dist.recv(tensor=tensor, src=1)
  else:
    # rank 1 receives from rank 0, increments the tensor, and sends it back
    dist.recv(tensor=tensor, src=0)
    tensor += 1
    dist.send(tensor=tensor, dst=0)

if __name__ == "__main__":
  size = 2
  processes = []
  mp.set_start_method("spawn")
  for rank in range(size):
    p = mp.Process(target=init_process, args=(rank, size, run))
    p.start()
    processes.append(p)
  for p in processes:
    p.join()

torch.multiprocessing is a wrapper around the native multiprocessing module.
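
As a side note, the manual Process loop above could also be replaced with torch.multiprocessing.spawn, which invokes the target function with the process index as its first argument. A minimal sketch, assuming the init_process and run functions defined above:

if __name__ == "__main__":
  size = 2
  # spawn calls init_process(rank, size, run) for rank = 0 .. size-1
  mp.spawn(init_process, args=(size, run), nprocs=size, join=True)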

Collective Communication

Broadcast

The tensor on the src rank is broadcast to all other ranks.
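
A minimal sketch of dist.broadcast, reusing the init_process launcher from the P2P example above (the function name run_broadcast is illustrative):

def run_broadcast(rank_id, size):
  if rank_id == 0:
    tensor = torch.arange(4, dtype=torch.float32)  # data exists only on rank 0
  else:
    tensor = torch.zeros(4)                        # placeholder buffer on the other ranks
  dist.broadcast(tensor, src=0)
  # every rank now holds rank 0's tensor: [0., 1., 2., 3.]
  print(f"rank {rank_id}: {tensor}")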

Scatter

Note that the destination ranks receive the scattered tensors in sequential order; that is, rank i receives scatter_list[i].
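
A minimal sketch of dist.scatter under the same launcher as above (run_scatter is an illustrative name); only the src rank passes scatter_list:

def run_scatter(rank_id, size):
  tensor = torch.zeros(1)
  if rank_id == 0:
    # rank i receives scatter_list[i]
    scatter_list = [torch.tensor([float(i)]) for i in range(size)]
    dist.scatter(tensor, scatter_list=scatter_list, src=0)
  else:
    dist.scatter(tensor, src=0)
  print(f"rank {rank_id} received {tensor}")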

Gather
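
Gather is the inverse of scatter: every rank sends its tensor to the dst rank, which collects them into gather_list. A minimal sketch under the same assumptions as the examples above (run_gather is an illustrative name):

def run_gather(rank_id, size):
  tensor = torch.tensor([float(rank_id)])  # each rank contributes its own rank id
  if rank_id == 0:
    gather_list = [torch.zeros(1) for _ in range(size)]
    dist.gather(tensor, gather_list=gather_list, dst=0)
    print(f"rank 0 gathered {gather_list}")
  else:
    dist.gather(tensor, dst=0)  # non-dst ranks pass no gather_list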

Reduce

PyTorch supports a variety of reduction operations, including SUM, PRODUCT, MIN, MAX, BAND, BOR, and BXOR.

It's important to note that only the dst rank is guaranteed to hold the final result; dist.reduce may write intermediate reduction values into the tensor on the other ranks, so their contents should not be relied on after the call.
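
A minimal sketch of dist.reduce with ReduceOp.SUM, under the same assumptions as the examples above (run_reduce is an illustrative name):

def run_reduce(rank_id, size):
  tensor = torch.tensor([float(rank_id + 1)])
  # sum the tensors from all ranks; only rank 0 (dst) is guaranteed the final result
  dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
  if rank_id == 0:
    print(f"rank 0 holds the sum {tensor}")  # 1 + 2 = 3 with two ranks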

All-gather and all-reduce
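
All-gather collects every rank's tensor onto every rank, and all-reduce leaves the reduced result on every rank; all-reduce of gradients is the core step of data-parallel training. A minimal sketch of both, under the same assumptions as the examples above (run_all is an illustrative name):

def run_all(rank_id, size):
  # all_reduce: every rank ends up with the same summed value
  tensor = torch.tensor([float(rank_id + 1)])
  dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
  print(f"rank {rank_id} all_reduce result: {tensor}")

  # all_gather: every rank receives the full list of per-rank tensors
  output = [torch.zeros(1) for _ in range(size)]
  dist.all_gather(output, torch.tensor([float(rank_id)]))
  print(f"rank {rank_id} all_gather result: {output}")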
