I am trying out distributed training in PyTorch using the "DistributedDataParallel" strategy on Databricks notebooks (or any notebook environment), but I am stuck on multi-processing in the Databricks notebook environment.
Problem: I want to spawn multiple processes on a Databricks notebook using torch.multiprocessing. I have extracted the problem from the main code for easier understanding:
import torch.distributed as dist
import torch.multiprocessing as mp

def train(rank):
    # mp.spawn passes the process index as the first argument to the target
    print("hello")

if __name__ == '__main__':
    processes = 4
    mp.spawn(train, args=(), nprocs=processes)
    print("completed")
Exception:
ProcessExitedException: process 1 terminated with exit code 1
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
<command-2917251930623656> in <module>
     19 if __name__ == '__main__':
     20     processes = 4
---> 21     mp.spawn(train, args=(), nprocs=processes)
     22     print("completed")
     23

/databricks/python/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
    228                ' torch.multiprocessing.start_processes(...)' % start_method)
    229         warnings.warn(msg)
--> 230     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')

/databricks/python/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    186
    187     # Loop on join until it returns True or raises an exception.
--> 188     while not context.join():
    189         pass
    190
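
For reference, the full pattern I am aiming for is the standard DistributedDataParallel setup below. This is a minimal sketch: the gloo backend and the MASTER_ADDR/MASTER_PORT values are placeholder assumptions for illustration, not part of my actual code.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def train(rank, world_size):
    # Each spawned worker joins the process group under its own rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # placeholder address
    os.environ.setdefault("MASTER_PORT", "29500")      # placeholder port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print(f"hello from rank {rank}")
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 4
    mp.spawn(train, args=(world_size,), nprocs=world_size)
    print("completed")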