Daniel F

Reputation: 14399

`multiprocessing.Pool` in stand-alone functions

Re-asking this with a more specific question as my last version got no responses.

I'm trying to make an importable function to do Stationary Wavelet Transforms of a dataframe with a series of (long) time histories. The actual theory isn't important (and I'm perhaps not even using it quite right), the important part is I'm breaking up the time history into blocks and feeding them to multiple threads using multiprocessing.Pool.

import pandas as pd
import numpy as np
import pywt
from multiprocessing import Pool
import functools

def swt_block(arr, level = 8, wvlt = 'haar'):
    block_length = arr.shape[0]
    if block_length == 2**level:
        d = pywt.swt(arr, wvlt, axis = 0)
    elif block_length < 2**level:
        arr_ = np.pad(arr, 
                      ((0, 2**level - block_length), (0,0)), 
                      'constant', constant_values = 0)
        d = pywt.swt(arr_, wvlt, axis = 0)
    else:
        raise ValueError('block of length ' + str(arr.shape[0]) + ' too large for swt of level ' + str(level))
    out = []
    for lvl in d:
        for coeff in lvl:
            out.append(coeff)
    return np.concatenate(out, axis = -1)[:block_length]


def swt(df, wvlt = 'haar', level = 8, processors = 4):
    block_length = 2**level
    with Pool(processors) as p:
        data = p.map(functools.partial(swt_block, level = level, wvlt = wvlt), 
                     [i.values for _, i in df.groupby(np.arange(len(df)) // block_length)])
    data = np.concatenate(data, axis = 0) 
    header = pd.MultiIndex.from_product([list(range(level)),
                                         [0, 1],
                                         df.columns],
                                        names = ['level', 'coef', 'channel'])
    df_out = pd.DataFrame(data, index = df.index, columns = header)

    return df_out

I have done this in a stand-alone script previously, so the code works if the second function is instead just bare code wrapped in if __name__ == '__main__':, and it indeed works in-script if I add a similar block to the end of the script above. But if I import this module (or even just run the above in an interpreter) and then do

df_swt = swt(df)

things hang indefinitely. I'm sure it's some sort of guard rail in multiprocessing to prevent me from doing something dumb with threads, but I'd really prefer not to have to copy this block of code into a bunch of other scripts. I'm including other tags in case they're somehow the culprits.

Upvotes: 1

Views: 210

Answers (1)

Tomerikoo

Reputation: 19431

First of all, just to be clear: you are creating multiple processes, not threads. If you are specifically interested in threads, change your import to: from multiprocessing.dummy import Pool.

From the multiprocessing introduction:

multiprocessing is a package that supports spawning processes using an API similar to the threading module.

From the multiprocessing.dummy section:

multiprocessing.dummy replicates the API of multiprocessing but is no more than a wrapper around the threading module.
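As a quick sanity check (a minimal sketch of my own, not code from the question), you can confirm that the dummy Pool's workers run as threads inside the current process:

```python
import os
from multiprocessing.dummy import Pool  # thread-based wrapper, same Pool API

def whoami(_):
    # With threads, every "worker" shares the parent's process ID.
    return os.getpid()

with Pool(4) as p:
    pids = p.map(whoami, range(8))

# All eight tasks ran in this single process -- these are threads.
assert set(pids) == {os.getpid()}
```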

Now, I was able to recreate your issue (following your previous linked question) and indeed the same happened: running in an interactive shell, things simply hung.

However, interestingly enough, when running through the Windows cmd, an endless chain of this error appeared on screen:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

So, as a wild guess, I added to the importing module:

if __name__ == "__main__":

And......... it worked!

Just to remove any doubt, I will post here the exact files I used so you can (hopefully) recreate the solution...

In multi.py:

from multiprocessing import Pool

def __foo(x):
    return x**2

def bar(list_of_inputs):
    with Pool() as p:
        out = p.map(__foo, list_of_inputs)
    print(out)

if __name__ == "__main__":
    bar(list(range(50)))

In tests.py:

from multi import bar

l = list(range(50))

if __name__ == "__main__":
    bar(l)

Output when running any of those 2 files (both in shell and through cmd):

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961, 1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521, 1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401]

UPDATE: This is in fact covered in the docs, under the "Safe importing of main module" programming guideline: when new processes are started with the spawn start method (the default on Windows), each child process re-imports the main module, so the main module must be importable without side effects such as starting new processes - hence the if __name__ == "__main__": guard.
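To make that re-import behavior visible, here is a small stand-alone sketch of my own (assuming it is saved as a script, e.g. guard_demo.py, and run directly): under the spawn start method, every worker re-imports the main module, which is exactly why unguarded top-level Pool creation recurses into the RuntimeError above.

```python
import multiprocessing as mp

def square(x):
    return x * x

# This line runs once in the parent and once more in every spawned child,
# because spawn re-imports the main module in each new process.
print("imported in process:", mp.current_process().name)

if __name__ == "__main__":
    # Without this guard, the Pool below would be created again during each
    # child's re-import, spawning children endlessly.
    mp.set_start_method("spawn", force=True)
    with mp.Pool(2) as p:
        results = p.map(square, [1, 2, 3])
    assert results == [1, 4, 9]
```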

As discussed at the start of this answer, it seems you intended to use threads and didn't realize you were using processes. If that is indeed the case, then using actual threads will solve your problem and will not require you to change anything except the import statement (change it to: from multiprocessing.dummy import Pool). With threads there is no restriction requiring the if __name__ == "__main__": guard, neither in the main module nor in the importing one. So this should work:

In multi.py:

from multiprocessing.dummy import Pool

def __foo(x):
    return x**2

def bar(list_of_inputs):
    with Pool() as p:
        out = p.map(__foo, list_of_inputs)
    print(out)

if __name__ == "__main__":
    bar(list(range(50)))

In tests.py:

from multi import bar

l = list(range(50))

bar(l)

I really hope this helps you solve your issue, please let me know if it does.

Upvotes: 1
