Rahul

Reputation: 3386

How to use PyTorch multiprocessing?

I'm trying to use Python's multiprocessing Pool with PyTorch to process an image. Here's the code:

from multiprocessing import Process, Pool
import torch
from torch.autograd import Variable
import numpy as np
from scipy.ndimage import zoom

def get_pred(args):
    img = args[0]
    scale = args[1]
    scales = args[2]
    img_scale = zoom(img.numpy(),
                     (1., 1., scale, scale),
                     order=1,
                     prefilter=False,
                     mode='nearest')

    # feed input data
    input_img = Variable(torch.from_numpy(img_scale),
                         volatile=True).cuda()
    return input_img

scales = [1,2,3,4,5]
scale_list = []
for scale in scales: 
    scale_list.append([img,scale,scales])
multi_pool = Pool(processes=5)
predictions = multi_pool.map(get_pred,scale_list)
multi_pool.close() 
multi_pool.join()

I'm getting this error:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

on this line:

predictions = multi_pool.map(get_pred,scale_list)

Can anyone tell me what I'm doing wrong?

Upvotes: 40

Views: 96650

Answers (2)

nicobonne

Reputation: 688

As stated in the PyTorch documentation, the best practice for multiprocessing is to use torch.multiprocessing instead of multiprocessing.

Be aware that sharing CUDA tensors between processes is supported only in Python 3, either with spawn or forkserver as start method.

Without touching the rest of your code, a workaround for the error you got is to replace

from multiprocessing import Process, Pool

with:

from torch.multiprocessing import Pool, Process, set_start_method
try:
    set_start_method('spawn')
except RuntimeError:
    pass
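
The try/except above is there because set_start_method raises RuntimeError when the start method has already been fixed for the process. A minimal sketch of that behaviour follows, using the standard multiprocessing module so it runs without PyTorch or a GPU; ensure_spawn is a hypothetical helper name, and the assumption is that torch.multiprocessing behaves the same way here, since it is documented as a wrapper around the standard module.

```python
import multiprocessing as mp


def ensure_spawn():
    """Try to select 'spawn' and report which start method is in effect.

    Hypothetical helper for illustration only.
    """
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when the start method was already set earlier in this
        # process; in that case the earlier choice is left in place.
        pass
    # Returning the active method lets the caller check that 'spawn'
    # really took effect instead of assuming it did.
    return mp.get_start_method()
```

Note that swallowing the exception also hides the case where something else already selected a different start method, so checking the returned value (or passing force=True to set_start_method) may be worth considering.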

Upvotes: 52

Oliver Baumann

Reputation: 2289

I suggest you read the docs for the multiprocessing module, especially this section. You need to change how subprocesses are created by calling set_start_method. The following example is taken from those docs:

import multiprocessing as mp

def foo(q):
    q.put('hello')

if __name__ == '__main__':
    mp.set_start_method('spawn')
    q = mp.Queue()
    p = mp.Process(target=foo, args=(q,))
    p.start()
    print(q.get())
    p.join()
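
The same idea can be applied to a Pool, as in the question. The sketch below is an illustration under stated assumptions, not the asker's actual pipeline: scaled_shape and run_with_spawn are hypothetical names, and the worker only does plain arithmetic as a stand-in for the CUDA work so that the example runs without PyTorch. It uses multiprocessing.get_context, which returns an object bound to one start method and leaves the global default untouched.

```python
import multiprocessing as mp


def scaled_shape(args):
    """Stand-in worker: return the (height, width) after scaling.

    Must live at module level so spawned workers can import it.
    """
    (height, width), scale = args
    return (height * scale, width * scale)


def run_with_spawn(shape, scales, processes=2):
    """Map scaled_shape over the scales in a pool of spawned workers."""
    # The context exposes the usual multiprocessing API, but every
    # process it creates is started with 'spawn' rather than 'fork'.
    ctx = mp.get_context("spawn")
    tasks = [(shape, scale) for scale in scales]
    with ctx.Pool(processes=processes) as pool:
        # map blocks until all results are ready, in input order.
        return pool.map(scaled_shape, tasks)
```

As in the docs example above, the call itself, for instance run_with_spawn((4, 4), [1, 2, 3]), should sit under an if __name__ == '__main__': guard, because spawn re-imports the main module in every worker.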

Upvotes: 11
