Tengis

Reputation: 2809

Difference between map() and pool.map()

I have code like this:

def plotFrame(n):
    a = data[n, :]
    do_something_with(a)

data = loadtxt(filename)
ids = data[:,0]  # some numbers from the first column of data
map(plotFrame, ids)

That worked fine for me. Now I want to try replacing map() with pool.map() as follows:

pools = multiprocessing.Pool(processes=1)
pools.map(plotFrame, ids)

But that doesn't work; it fails with:

NameError: global name 'data' is not defined

The question is: what is going on? Why does map() not complain about the data variable, which is not passed to the function, while pool.map() does?

EDIT: I'm using Linux.

EDIT 2: Based on @Bill's second suggestion, I now have the following code:

def plotFrame_v2(line):
    plot_with(line)

if __name__ == "__main__":
    ff = np.loadtxt(filename)
    m = int( max(ff[:,-1]) ) # max id
    l = ff.shape[0]
    nfig = 0
    pool = Pool(processes=1)
    for i in range(0, l // m, 50):
        data = ff[i*m:(i+1)*m, :] # data of one frame contains several ids
        pool.map(plotFrame_v2, data)
        nfig += 1        
        plt.savefig("figs_bot/%.3d.png"%nfig) 
        plt.clf() 

That runs as expected. However, I now have another unexpected problem: the produced figures are blank, whereas the same code with map() produces figures showing the contents of data.

Upvotes: 3

Views: 1706

Answers (2)

jfs

Reputation: 414255

To avoid "unexpected" problems, avoid globals.

Your first code example uses the builtin map to call plotFrame:

def plotFrame(n):
    a = data[n, :]
    do_something_with(a)

To reproduce it with multiprocessing.Pool.map, the first thing to deal with is the global data. If do_something_with(a) also uses global data, it should be changed as well.

For how to pass a numpy array to a child process, see Use numpy array in shared memory for multiprocessing. If you don't need to modify the array, it is even simpler:

import numpy as np
from multiprocessing import Pool

def init(data_): # inherit data
    global data #NOTE: no other globals in the program
    data = data_

# plotFrame is defined as above and reads the global data set by init()

def main():
    data = np.loadtxt(filename)
    ids = data[:,0].astype(int)  # row indices from the first column of data
    pool = Pool(initializer=init, initargs=[data])
    pool.map(plotFrame, ids)
    pool.close()
    pool.join()

if __name__=="__main__":
    main()

All data should either be passed explicitly as arguments to plotFrame or be inherited via init().
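A minimal sketch of the same initializer pattern that runs without numpy or a data file. It assumes a plain list of rows stands in for the array, and `run` plus the row-summing body of `plot_frame` are hypothetical placeholders for the real plotting code. Each worker gets the rows once, when `init()` runs at worker start-up, instead of relying on a global created in the parent.

```python
from multiprocessing import Pool

# Filled in separately inside each worker by init(); the parent never sets it.
data = None


def init(data_):
    """Runs once in every worker process: store the shared rows."""
    global data
    data = data_


def plot_frame(n):
    """Stand-in for plotFrame: fetch row n from the inherited data."""
    row = data[n]
    return sum(row)  # hypothetical placeholder for do_something_with(row)


def run(rows, ids, processes=2):
    """Hypothetical driver: hand the rows to each worker via initargs."""
    with Pool(processes=processes, initializer=init, initargs=(rows,)) as pool:
        return pool.map(plot_frame, ids)


if __name__ == "__main__":
    rows = [[0, 1.0, 2.0], [1, 3.0, 4.0], [2, 5.0, 6.0]]
    ids = [int(r[0]) for r in rows]
    print(run(rows, ids))
```

Because `plot_frame` only reads `data` after `init()` has run in that worker, it does not matter which start method the platform uses.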

Your second code example tries to manipulate global data again (via plt calls):

import matplotlib.pyplot as plt

#XXX BROKEN, DO NOT USE
pool.map(plotFrame_v2, data)
nfig += 1        
plt.savefig("figs_bot/%.3d.png"%nfig) 
plt.clf()

Unless you draw something in the main process, this code saves blank figures. Either plot in the child processes, or send the data to be plotted to the parent process explicitly, e.g., by returning it from plotFrame and using the value returned by pool.map(). Here's a code example: how to plot in child processes.
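A sketch of the second option (workers compute, the parent plots), assuming the per-frame work can be reduced to plain x/y lists. `compute_frame` and `collect_frames` are hypothetical names, and matplotlib is left out so the sketch runs anywhere; the comment in the loop marks where the parent would make the plt calls, since pyplot state changed in a worker never reaches the parent's figure.

```python
from multiprocessing import Pool


def compute_frame(row):
    """Worker side: do the numeric work and return plain data, not a figure."""
    # Hypothetical placeholder: x is the column index, y is the squared value.
    xs = list(range(len(row)))
    ys = [v * v for v in row]
    return xs, ys


def collect_frames(rows, processes=2):
    """Parent side: farm rows out to the pool and gather the returned data."""
    with Pool(processes=processes) as pool:
        return pool.map(compute_frame, rows)


if __name__ == "__main__":
    frames = collect_frames([[1, 2], [3, 4]])
    for nfig, (xs, ys) in enumerate(frames, start=1):
        # In the real script the plotting happens here, in the parent, e.g.
        # plt.plot(xs, ys); plt.savefig("figs_bot/%.3d.png" % nfig); plt.clf()
        print(nfig, xs, ys)
```

pool.map() returns the results in the same order as the input rows, so the figure numbering stays aligned with the frames.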

Upvotes: 2

wflynny

Reputation: 18521

When you use multiprocessing.Pool, you spawn separate processes that work with the shared (global) resource data. Typically (e.g. with the fork start method traditionally used on Linux), the child processes can read a resource from the parent process if you make that resource a module-level global that exists before the pool is created. However, it is better practice to explicitly pass all needed resources to the child processes as function arguments. This is required if you are working on Windows. Check out the multiprocessing guidelines here.

So you could try:

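import multiprocessing
from numpy import loadtxt

data = loadtxt(filename)  # module level, so forked workers inherit it

def plotFrame(n):
    global data  # not needed just to read data, but makes the dependency explicit
    a = data[n, :]
    do_something_with(a)

ids = data[:,0].astype(int)  # numbers from the first column of data, used as row indices
pools = multiprocessing.Pool(processes=1)
pools.map(plotFrame, ids)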

Or, even better, see this thread about passing multiple arguments to a function with multiprocessing.Pool. A simple way could be:

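from multiprocessing import Pool
from numpy import loadtxt

def plotFrameWrapper(args):
    # unpack one (n, data) tuple into plotFrame's two arguments
    return plotFrame(*args)

def plotFrame(n, data):
    a = data[n, :]
    do_something_with(a)

if __name__ == "__main__":
    data = loadtxt(filename)
    pools = Pool(1)

    ids = data[:,0].astype(int)
    # pair every id with the full array, in the (n, data) order plotFrame expects
    results = pools.map(plotFrameWrapper, zip(ids, [data]*len(ids)))
    print(results)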
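On Python 3.3 or later, the standard library offers two ways to do this without a hand-written wrapper: `Pool.starmap`, which unpacks each argument tuple itself, and `functools.partial`, which binds the shared array as a fixed keyword argument. The sketch below assumes plain lists of rows in place of the numpy array; `run_starmap`, `run_partial` and the body of `plot_frame` are hypothetical placeholders. Both variants still send a pickled copy of `data` to the workers along with the tasks, as the zip-based version does.

```python
from functools import partial
from multiprocessing import Pool


def plot_frame(n, data):
    """Stand-in for plotFrame(n, data): both inputs arrive as arguments."""
    row = data[n]
    return row[-1]  # hypothetical placeholder for do_something_with(row)


def run_starmap(data, ids, processes=2):
    # starmap calls plot_frame(n, data) for every (n, data) tuple.
    with Pool(processes=processes) as pool:
        return pool.starmap(plot_frame, [(n, data) for n in ids])


def run_partial(data, ids, processes=2):
    # partial(plot_frame, data=data) turns each call into plot_frame(n, data=data);
    # it pickles fine because plot_frame is a module-level function.
    with Pool(processes=processes) as pool:
        return pool.map(partial(plot_frame, data=data), ids)


if __name__ == "__main__":
    rows = [[0, 10], [1, 20], [2, 30]]
    print(run_starmap(rows, [2, 0]))
    print(run_partial(rows, [1, 2]))
```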

One last thing: since it looks like the only thing your example does with data is slice it, you can simply slice first and then pass the slices to your function:

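from multiprocessing import Pool
from numpy import loadtxt

def plotFrame(sliced_data):
    do_something_with(sliced_data)

if __name__ == "__main__":
    data = loadtxt(filename)
    pools = Pool(1)

    ids = data[:,0].astype(int)
    # data[ids] selects the needed rows; pool.map hands them out one row at a time
    results = pools.map(plotFrame, data[ids])
    print(results)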

Upvotes: 4
