Reputation: 31
I've developed this little POC using the PyTorch distributed package: essentially, a Trainer spawns N processes and orchestrates them using Python Pipes (it could also be Queues). Normally it would send data at every epoch, but in this POC the data is sent only once, at process creation. The processes train a model through DDP.
import os
import signal
import socket
from contextlib import closing
from multiprocessing.connection import Connection, Pipe
from typing import List
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
def init_process(rank, world_size, ddp_free_port, recv, train_data):
"""Initialize the distributed environment."""
torch.set_num_threads(1)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = ddp_free_port
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["NODE_RANK"] = "0"
dist.init_process_group("gloo", init_method=f"tcp://localhost:{ddp_free_port}", rank=rank, world_size=world_size)
Worker(recv, train_data).train()
class Worker:
def __init__(self, queue, train_dset):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.queue: Connection = queue
self.train_dset = train_dset
self.model = torch.nn.Sequential(nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
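        # Wrap the model in DDP so gradients are averaged across all ranks during backward().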
self.model = DDP(self.model)
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
def train(self):
loss_fn = nn.CrossEntropyLoss()
sampler = torch.utils.data.distributed.DistributedSampler(
self.train_dset, num_replicas=self.world_size, rank=self.rank, shuffle=True
)
train_loader = torch.utils.data.DataLoader(self.train_dset, sampler=sampler, batch_size=32)
while True:
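            # Block until the Trainer sends the next epoch index over the pipe; False means shut down.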
epoch = self.queue.recv()
if epoch is False:
print(f"Rank-{self.rank} done!")
return
total_loss = 0
sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
images, labels = batch
out = self.model(images.view(-1, 28 * 28))
loss = loss_fn(out, labels)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item()
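            # Wait for every rank to finish the epoch; only rank 0 reports back to the Trainer.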
dist.barrier()
if self.rank == 0:
print(f"Epoch: {epoch}, Loss@rank-{self.rank}: {total_loss / len(train_loader):.4f}")
print(f"Rank-0 is telling the trainer that everything is done for the epoch {epoch}")
self.queue.send(True)
class Trainer:
def __init__(self, world_size: int, epochs: int = 5) -> None:
self.world_size = world_size
self.epochs = epochs
self.train_data = torchvision.datasets.MNIST(
"/tmp/data",
train=True,
download=True,
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)
self.test_data = torchvision.datasets.MNIST(
"/tmp/data",
train=False,
download=True,
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)
self.ddp_free_port = str(find_free_port())
def run(self):
"""Run the distributed environment."""
print("Start training")
queues = []
processes = []
for rank in range(self.world_size):
if rank == 0:
recv, send = Pipe(duplex=True)
else:
recv, send = Pipe(duplex=False)
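            # The read end of the pipe goes to the child, the Trainer keeps the write end; only rank 0's pipe is duplex so it can report back.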
p = mp.Process(
target=init_process,
args=(rank, self.world_size, self.ddp_free_port, recv, self.train_data),
daemon=True,
)
p.start()
queues.append(send)
processes.append(p.pid)
self.train(queues, processes)
def train(self, queues, processes):
for epoch in range(self.epochs):
for rank in range(self.world_size):
queues[rank].send(epoch)
print("Training waiting for rank-0")
queues[0].recv()
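        # All epochs are done: tell every worker to stop, close the write end of its pipe and terminate it.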
for rank in range(self.world_size):
queues[rank].send(False)
queues[rank].close()
os.kill(processes[rank], signal.SIGTERM)
def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
if __name__ == "__main__":
os.environ["LOGLEVEL"] = "DEBUG"
mp.set_start_method("spawn")
trainer = Trainer(world_size=16)
trainer.run()
print("Finished training")
I receive the following error from the spawned processes, apparently at random, if I increase the number of processes from 16 to 32, for example:
...
Process Process-1:
Traceback (most recent call last):
File "C:\Program Files (x86)\Python38\lib\multiprocessing\process.py", line 315, in _bootstrap
self.run()
File "C:\Program Files (x86)\Python38\lib\multiprocessing\process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "c:\Users\belof\Desktop\temp\examples\ddp_cpu.py", line 27, in init_process
dist.init_process_group("gloo", init_method=f"tcp://localhost:{ddp_free_port}", rank=rank, world_size=world_size)
File "C:\Users\belof\Desktop\temp\.venv\lib\site-packages\torch\distributed\distributed_c10d.py", line 602, in init_process_group
default_pg = _new_process_group_helper(
File "C:\Users\belof\Desktop\temp\.venv\lib\site-packages\torch\distributed\distributed_c10d.py", line 703, in _new_process_group_helper
pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout)
RuntimeError: Socket Timeout
Traceback (most recent call last):
File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 312, in _recv_bytes
nread, err = ov.GetOverlappedResult(True)
BrokenPipeError: [WinError 109] The pipe has been ended
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 131, in <module>
trainer.run()
File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 106, in run
self.train(queues, processes)
File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 113, in train
queues[0].recv()
File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 250, in recv
buf = self._recv_bytes()
File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 321, in _recv_bytes
raise EOFError
EOFError
It seems to me to be something related to the Windows spawn start method and the queue references passed to the processes, but I don't really know what is happening here. This is the output of the collect_env.py script:
Collecting environment information...
PyTorch version: 1.12.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.8.8 (tags/v3.8.8:024d805, Feb 19 2021, 13:18:16) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.931
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] pytorch-lightning==1.6.4
[pip3] torch==1.12.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.12.0
[conda] Could not collect
Upvotes: 0
Views: 803
Reputation: 31
As a workaround, I've passed a huge timeout to the init_process_group function. With the spawn start method every child has to re-import the script and unpickle its arguments (including the whole MNIST dataset) before it can reach init_process_group, so I suspect that with many processes the last ranks join the rendezvous too late for the default timeout:
from datetime import timedelta
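# A huge timeout so that even the slowest-starting ranks manage to join the rendezvous.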
dist.init_process_group(
"gloo",
init_method=f"tcp://localhost:{ddp_free_port}",
rank=rank,
world_size=world_size,
timeout=timedelta(days=1),
)
This lets me run the script with 64 processes, for example.
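Another option would be to not pass train_data through the Process args at all and let each worker build the dataset itself, so the children have less to unpickle before they reach the rendezvous. This is just a sketch (assuming the dataset pickling is the main cost, which I haven't verified); it is a drop-in replacement for init_process from the question, and the args tuple in Trainer.run would have to drop train_data accordingly:
def init_process(rank, world_size, ddp_free_port, recv):
    """Each child builds its own dataset instead of receiving a pickled copy."""
    torch.set_num_threads(1)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = ddp_free_port
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["NODE_RANK"] = "0"
    dist.init_process_group("gloo", init_method=f"tcp://localhost:{ddp_free_port}", rank=rank, world_size=world_size)
    # The Trainer already downloaded MNIST to /tmp/data, so this only reads the cached files.
    train_data = torchvision.datasets.MNIST(
        "/tmp/data",
        train=True,
        download=False,
        transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
    )
    Worker(recv, train_data).train()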
Upvotes: 1