# PyTorch Distributed Training

[PyTorch Distributed Communication Primitives](https://zhuanlan.zhihu.com/p/478953028)

[PyTorch allreduce Distributed Training](https://zhuanlan.zhihu.com/p/482557067)

## PyTorch Distributed Communication Primitives

`torch.distributed` provides two kinds of communication: (1) point-to-point (P2P) communication via `send` and `recv`, and (2) collective communication.

### P2P Communication

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank_id, size, fn, backend='gloo'):
  os.environ['MASTER_ADDR'] = '127.0.0.1'
  os.environ['MASTER_PORT'] = '29500'
  dist.init_process_group(backend, rank=rank_id, world_size=size)
  fn(rank_id, size)
  
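def run(rank_id, size):
  tensor = torch.zeros(1)
  if rank_id == 0:
    tensor += 1
    # Blocking send to rank 1, then wait for its reply.
    dist.send(tensor=tensor, dst=1)
    dist.recv(tensor=tensor, src=1)
  else:
    # Blocking receive from rank 0, increment, and send the result back.
    dist.recv(tensor=tensor, src=0)
    tensor += 1
    dist.send(tensor=tensor, dst=0)
  print(f"rank {rank_id} has {tensor}")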

if __name__ == "__main__":
  size = 2
  processes = []
  mp.set_start_method("spawn")
  for rank in range(size):
    p = mp.Process(target=init_process, args=(rank, size, run))
    p.start()
    processes.append(p)
  for p in processes:
    p.join()
```

`torch.multiprocessing` is a wrapper around Python's native `multiprocessing` module.
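The exchange in `run` above is a blocking ping-pong: rank 0 sends first and then waits, while rank 1 waits first and then replies. The following is a minimal sketch of that ordering using threads and queues from the standard library. The `send` and `recv` helpers and the `channels` dictionary are hypothetical stand-ins, not part of `torch.distributed`, and a plain integer replaces the tensor.

```python
import queue
import threading

# Hypothetical stand-ins for dist.send / dist.recv: one FIFO channel per
# (source, destination) pair. This only mimics the order of the exchange.
channels = {(0, 1): queue.Queue(), (1, 0): queue.Queue()}
results = {}

def send(value, src, dst):
    channels[(src, dst)].put(value)

def recv(src, dst):
    # Blocks until the peer has sent a value.
    return channels[(src, dst)].get()

def run(rank_id):
    value = 0
    if rank_id == 0:
        value += 1
        send(value, src=0, dst=1)
        value = recv(src=1, dst=0)
    else:
        value = recv(src=0, dst=1)
        value += 1
        send(value, src=1, dst=0)
    results[rank_id] = value

threads = [threading.Thread(target=run, args=(r,)) for r in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Both entries of `results` end up as 2, matching the final tensor value on both ranks in the `torch.distributed` version. If both ranks called `recv` first, neither would ever send, and the program would hang.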

### Collective Communication

**Broadcast**

```python
def run(rank_id, size):
    # Rank 0 starts with [1, 2], rank 1 with [3, 4], and so on.
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    # After the call, every rank holds rank 0's values.
    dist.broadcast(tensor, src=0)
```

The `tensor` on the `src` rank is broadcast to all other ranks, whose own `tensor` is overwritten in place with the received values.
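To make the effect concrete without launching processes, here is a pure-Python sketch that simulates the broadcast for four ranks. `initial_tensors` and `simulate_broadcast` are hypothetical helpers written for this note, and plain lists stand in for tensors.

```python
def initial_tensors(world_size):
    # Mirrors torch.arange(2) + 1 + 2 * rank_id from the snippet above.
    return [[1 + 2 * r, 2 + 2 * r] for r in range(world_size)]

def simulate_broadcast(tensors, src):
    # Every rank ends up with a copy of the src rank's data.
    return [list(tensors[src]) for _ in tensors]

before = initial_tensors(4)
after = simulate_broadcast(before, src=0)

for rank, (b, a) in enumerate(zip(before, after)):
    print(f"rank {rank}: {b} -> {a}")
```

With four ranks, the initial values `[1, 2]`, `[3, 4]`, `[5, 6]`, `[7, 8]` all become `[1, 2]`.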

**Scatter**

```python
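def run(rank_id, size):
    # Receive buffer; its contents are overwritten by the scattered chunk.
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        # One chunk per rank: [0, 0], [1, 1], ..., [size - 1, size - 1].
        scatter_list = [torch.full((2,), i, dtype=torch.int64) for i in range(size)]
        dist.scatter(tensor, src=0, scatter_list=scatter_list)
    else:
        # Only the src rank supplies scatter_list.
        dist.scatter(tensor, src=0)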
```

Note that the destination ranks receive the scattered tensors in sequential order; that is, rank `i` receives `scatter_list[i]`.
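The ordering rule can be illustrated with a small pure-Python sketch. `simulate_scatter` is a hypothetical helper, not a `torch.distributed` function. It assumes one chunk per rank, so the list length must equal the world size.

```python
def simulate_scatter(scatter_list, world_size):
    # One chunk per rank is required.
    if len(scatter_list) != world_size:
        raise ValueError("scatter_list must contain one entry per rank")
    # Rank i receives scatter_list[i].
    return {rank: list(scatter_list[rank]) for rank in range(world_size)}

scatter_list = [[0, 0], [1, 1], [2, 2], [3, 3]]
received = simulate_scatter(scatter_list, world_size=4)

for rank in sorted(received):
    print(f"rank {rank} received {received[rank]}")
```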

**Gather**

```python
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    if rank_id == 0:
        # The dst rank provides one receive buffer per rank.
        gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(size)]
        dist.gather(tensor, dst=0, gather_list=gather_list)
    else:
        # Other ranks only contribute their tensor.
        dist.gather(tensor, dst=0)
```
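Gather is the inverse of scatter: the `dst` rank collects one tensor from every rank, ordered by rank. A pure-Python sketch of the result for four ranks follows; `initial_tensors` and `simulate_gather` are hypothetical helpers, with lists standing in for tensors.

```python
def initial_tensors(world_size):
    # Mirrors torch.arange(2) + 1 + 2 * rank_id from the snippet above.
    return [[1 + 2 * r, 2 + 2 * r] for r in range(world_size)]

def simulate_gather(tensors):
    # gather_list[i] on the dst rank holds the tensor sent by rank i.
    return [list(t) for t in tensors]

gather_list = simulate_gather(initial_tensors(4))
print(gather_list)
```

Only the `dst` rank ends up with this list; the other ranks receive nothing.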

**Reduce**

```python
def run(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.reduce(tensor, dst=3, op=dist.ReduceOp.SUM)
```

PyTorch supports a variety of reduction operations, including SUM, PRODUCT, MIN, MAX, BAND, BOR, and BXOR.

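Note that only the `dst` rank is guaranteed to hold the final reduced result. The `tensor` on the other ranks **might** still be altered in place, because it can be used to hold intermediate results, so do not rely on its value after the call.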
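To see what each reduction operation produces, the sketch below applies plain-Python equivalents element-wise across the four ranks' initial values. The `OPS` table and the helper names are illustrative assumptions, not the `dist.ReduceOp` implementation.

```python
from functools import reduce
import operator

def initial_tensors(world_size):
    # Mirrors torch.arange(2) + 1 + 2 * rank_id from the snippet above.
    return [[1 + 2 * r, 2 + 2 * r] for r in range(world_size)]

# Plain-Python equivalents of the reduction operations listed above.
OPS = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
    "BAND": operator.and_,
    "BOR": operator.or_,
    "BXOR": operator.xor,
}

def simulate_reduce(tensors, op_name):
    op = OPS[op_name]
    # zip(*tensors) groups the i-th element of every rank together.
    return [reduce(op, column) for column in zip(*tensors)]

tensors = initial_tensors(4)
results = {name: simulate_reduce(tensors, name) for name in OPS}

for name, value in results.items():
    print(name, value)
```

For example, SUM over `[1, 2]`, `[3, 4]`, `[5, 6]`, `[7, 8]` gives `[16, 20]`, which is what the `dst` rank holds after `dist.reduce` with `ReduceOp.SUM` on four ranks.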

**All-gather and all-reduce**

```python
def run_all_gather(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    gather_list = [torch.zeros(2, dtype=torch.int64) for _ in range(4)]
    dist.all_gather(gather_list, tensor)

def run_all_reduce(rank_id, size):
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank_id
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
```
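The "all" variants differ from `gather` and `reduce` in that every rank receives the result, not only a single destination rank. A pure-Python sketch for four ranks, using hypothetical helper names and lists in place of tensors:

```python
def initial_tensors(world_size):
    # Mirrors torch.arange(2) + 1 + 2 * rank_id from the snippet above.
    return [[1 + 2 * r, 2 + 2 * r] for r in range(world_size)]

def simulate_all_gather(tensors):
    # Each rank receives its own copy of the full, rank-ordered list.
    return [[list(t) for t in tensors] for _ in tensors]

def simulate_all_reduce_sum(tensors):
    # Element-wise sum across ranks, then a copy of it for every rank.
    total = [sum(column) for column in zip(*tensors)]
    return [list(total) for _ in tensors]

tensors = initial_tensors(4)
gathered = simulate_all_gather(tensors)
reduced = simulate_all_reduce_sum(tensors)

for rank in range(4):
    print(f"rank {rank}: all_gather={gathered[rank]} all_reduce={reduced[rank]}")
```

All-reduce with SUM is the primitive typically used to combine gradients in data-parallel training, since every worker needs the same aggregated values.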

