Theo75

Reputation: 487

Understanding multiprocessing / Manager Queue in Python

I'm having difficulty understanding the Manager queue in Python's multiprocessing.

My teacher gave me this code:

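def check_prime(n):
    # Tests whether n is prime and returns a boolean. (The actual implementation
    # was not shown; simple trial division is assumed here.)
    if n < 2:
        return False
    for d in range(2, int(n ** 0.5) + 1):
        if n % d == 0:
            return False
    return True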

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def find_prime_worker(input, output):
    for chunk in iter(input.get,None):
        print('(Pid: {})'.format(getpid()))
        primes_found = list(filter(check_prime,chunk))
        output.put(primes_found)

from multiprocessing import Pool,Manager
from itertools import chain

def calculate_primes(ncore,N,chunksize):
    with Manager() as manager:
        input = manager.Queue()
        output = manager.Queue()

        with Pool(ncore) as p:
            it = p.starmap_async(find_prime_worker,[(input,output)]*ncore)
            for r in chunks(range(1,N),chunksize):
                input.put(r)
            for i in range(ncore): input.put(None)
            it.wait()
            output.put(None)

        res = list(chain(*list(iter(output.get,None))))
    return res

calculate_primes(8,100,5)

1. If I add the following line to the find_prime_worker function:

    print('(Pid: {})'.format(getpid()))

then calculate_primes() returns an empty result, []. If I remove this line, the function works fine. I would like to know which process is executing each chunk without breaking the function. How can I do that?

2. I don't understand the calculate_primes() function: we define two Manager queues, input and output.

I don't understand why find_prime_worker is called through starmap_async with the argument list [(input,output)]*ncore.

Thanks a lot.

Upvotes: 0

Views: 98

Answers (1)

Booboo

Reputation: 44128

If you want to print the process id, you need:

from os import getpid

Without that statement, your worker function, find_prime_worker, raises a NameError as soon as it reaches getpid(). Because you are only waiting for the results to be ready by calling it.wait(), rather than actually retrieving them with it.get(), the exception is never propagated back to the main process. And since every task dies on its first chunk before putting anything on the output queue, the output queue ends up holding only your final None, which is why you get an empty list. Change your code to return_values = it.get() and you will see the exception. Add the import statement and the exception goes away; return_values will then be a list of all the return values from find_prime_worker, which will be [None] * 8, since that function implicitly returns None and is invoked ncore (8) times.

By the way, the name it is poorly chosen, since it suggests (at least to me) that starmap_async returns an iterator, when in fact it returns a multiprocessing.pool.AsyncResult instance. result would be a better name for this return value, in my opinion.
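A minimal, self-contained sketch of this behaviour, assuming a throwaway pool of 2 processes; broken_worker, fixed_worker and run are hypothetical names used only for illustration. AsyncResult.wait() returns quietly even though every task failed, while AsyncResult.get() re-raises the worker's NameError in the main process:

```python
from multiprocessing import Pool


def broken_worker(x):
    # getpid is deliberately NOT imported (as in the question's code),
    # so this raises NameError inside the pool process.
    return getpid(), x  # noqa: F821


def fixed_worker(x):
    from os import getpid  # the missing import
    return getpid(), x


def run(worker):
    """Submit 4 tasks with starmap_async and report what the AsyncResult says."""
    with Pool(2) as p:
        result = p.starmap_async(worker, [(i,) for i in range(4)])
        result.wait()  # blocks until all tasks are done; never raises
        if result.successful():
            return 'ok', result.get()
        try:
            result.get()  # re-raises the worker's exception in this process
        except NameError as e:
            return 'error', str(e)


if __name__ == '__main__':
    print(run(broken_worker))
    print(run(fixed_worker))
```

With broken_worker, run returns ('error', ...) carrying the NameError message; with fixed_worker it returns ('ok', [...]) with one (pid, x) pair per task.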

Which brings me to your second question. The program finds all the primes from 1 to 99 by creating a process pool with ncore processes, where ncore is set to 8. [(input,output)]*ncore creates a list of 8 tuples, each of which is (input, output). starmap_async unpacks each tuple into the arguments of one call, so the statement it = p.starmap_async(find_prime_worker,[(input,output)]*ncore) creates 8 tasks, each invoking find_prime_worker(input, output). The tasks are therefore identical: each reads its input by taking messages from the input queue and writes its results to the output queue. The messages on the input queue are just the numbers 1 through 99 broken up into ranges of size 5: range(1, 6), range(6, 11), range(11, 16) ... range(96, 100). Each task keeps taking chunks until it gets None (that is what iter(input.get, None) does), which is why the main process puts exactly ncore None values on the input queue after the real chunks: one to stop each of the 8 tasks.
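To see exactly what starmap_async is being given, here is a process-free sketch. As an assumption for illustration only, plain strings stand in for the two queue proxies, and itertools.starmap shows the same unpacking rule (each tuple t becomes the call f(*t)) that Pool.starmap_async applies across processes. chunks is copied from the question; describe is a hypothetical stand-in for find_prime_worker:

```python
from itertools import starmap


def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def describe(input, output):
    # Hypothetical stand-in for find_prime_worker: just echo its arguments.
    return f'worker({input}, {output})'


ncore = 8
args = [('input', 'output')] * ncore   # 8 references to the same tuple
calls = list(starmap(describe, args))  # each tuple is unpacked into 2 arguments

# What calculate_primes(8, 100, 5) puts on the input queue:
pieces = list(chunks(range(1, 100), 5))

print(len(args), calls[0])    # 8 worker(input, output)
print(len(pieces))            # 20
print(pieces[0], pieces[-1])  # range(1, 6) range(96, 100)
```

The 20 chunks correspond to the 20 (Pid: ...) lines in each of the outputs shown below.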


Extra Information Only If You Are Interested

The code is unusual in that a multiprocessing pool already uses queues internally to pass arguments and return values, which removes the need to create queues explicitly for that purpose. Yet the code uses the pool's internal queues to pass managed queues as arguments, and those managed queues hold values that the pool's own argument-passing mechanism could have handled more simply and transparently:

def check_prime(n):
    # Not a real implementation, and not all the primes between 1 and 100:
    if n in (3, 5, 7, 11, 13, 17, 19, 23, 97):
        return True
    return False

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def find_prime_worker(chunk):
    from os import getpid

    print('(Pid: {})'.format(getpid()))
    return list(filter(check_prime, chunk))


def calculate_primes(ncore, N, chunksize):
    from multiprocessing import Pool

    with Pool(ncore) as p:
        results = p.map(find_prime_worker, chunks(range(1, N), chunksize))
    res = []
    for result in results:
        res.extend(result)
    return res

# Required on Windows (and wherever processes are started with spawn):
if __name__ == '__main__':
    print(calculate_primes(8, 100, 5))

Prints:

(Pid: 9440)
(Pid: 9440)
(Pid: 9440)
(Pid: 9440)
(Pid: 9440)
(Pid: 9112)
(Pid: 9440)
(Pid: 9112)
(Pid: 9440)
(Pid: 9440)
(Pid: 9112)
(Pid: 14644)
(Pid: 14644)
(Pid: 9112)
(Pid: 9440)
(Pid: 14644)
(Pid: 9112)
(Pid: 9112)
(Pid: 14644)
(Pid: 9440)
[3, 5, 7, 11, 13, 17, 19, 23, 97]
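Rather than printing, each task can return its process id together with its result and let the main process tally how many chunks each pool process handled, which also answers the first question without relying on print. This is a sketch under assumptions: the time.sleep(0.01) is a hypothetical stand-in for real CPU work that gives the other pool processes a chance to run, find_prime_worker_with_pid and calculate_primes_with_pids are illustrative names, and check_prime is the same placeholder as above:

```python
import time
from collections import Counter
from multiprocessing import Pool
from os import getpid


def check_prime(n):
    # Same placeholder as above: not a real primality test.
    return n in (3, 5, 7, 11, 13, 17, 19, 23, 97)


def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def find_prime_worker_with_pid(chunk):
    time.sleep(0.01)  # hypothetical stand-in for real work
    return getpid(), list(filter(check_prime, chunk))


def calculate_primes_with_pids(ncore, N, chunksize):
    with Pool(ncore) as p:
        results = p.map(find_prime_worker_with_pid, chunks(range(1, N), chunksize))
    pids = Counter(pid for pid, _ in results)  # pid -> number of chunks handled
    primes = [n for _, found in results for n in found]
    return primes, pids


if __name__ == '__main__':
    primes, pids = calculate_primes_with_pids(8, 100, 5)
    print(primes)
    for pid, count in pids.items():
        print('(Pid: {}) handled {} chunk(s)'.format(pid, count))
```

Because map preserves the order of its inputs, the primes come back sorted even though the chunks were processed by different processes; how the chunks are spread across PIDs will differ from run to run.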

But let's stick with the code's original methodology. If you are interested, the code can be rewritten to use the more performant multiprocessing.Queue instead of the managed queue it currently uses. A multiprocessing.Queue cannot, however, be passed as an argument to a pool worker function. Instead, you have to use the initializer and initargs arguments of the multiprocessing.pool.Pool constructor to initialize global variables in each pool process with these queue references. And since the worker function now references the queues as global variables, it no longer takes any arguments, so starmap_async and the other map methods no longer seem appropriate unless we add back to find_prime_worker a dummy argument that is never used. Method apply_async therefore seems the more logical choice:

def init_processes(inq, outq):
    global input, output
    input = inq
    output = outq

def check_prime(n):
    # Not a real implementation, and not all the primes between 1 and 100:
    if n in (3, 5, 7, 11, 13, 17, 19, 23, 97):
        return True
    return False

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def find_prime_worker():
    from os import getpid

    for chunk in iter(input.get, None):
        print('(Pid: {})'.format(getpid()))
        primes_found = list(filter(check_prime, chunk))
        output.put(primes_found)

def calculate_primes(ncore, N, chunksize):
    from multiprocessing import Pool, Queue
    from itertools import chain

    input = Queue()
    output = Queue()

    with Pool(ncore, initializer=init_processes, initargs=(input, output)) as p:
        results = [p.apply_async(find_prime_worker) for _ in range(ncore)]
        for r in chunks(range(1, N), chunksize):
            input.put(r)
        for i in range(ncore):
            input.put(None)
        # Wait for all tasks to complete.
        # The actual return values are not of interest;
        # they are just None -- this is just for demo purposes but
        # calling result.get() will raise an exception if find_prime_worker
        # raised an exception:
        return_values = [result.get() for result in results]
        print(return_values)

        output.put(None)
        res = list(chain(*list(iter(output.get,None))))
    return res

# Required on Windows (and wherever processes are started with spawn):
if __name__ == '__main__':
    print(calculate_primes(8, 100, 5))

Prints:

(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
(Pid: 10516)
[None, None, None, None, None, None, None, None]
[3, 5, 7, 11, 13, 17, 19, 23, 97]
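The claim that a plain multiprocessing.Queue cannot be passed as an argument to a pool task, while a managed queue proxy can, is easy to check. In this sketch put_one and try_pass are hypothetical helpers; the exact error text may vary between Python versions, so only the exception type is relied on:

```python
from multiprocessing import Manager, Pool, Queue


def put_one(q):
    # Task that writes a single message to whatever queue it is given.
    q.put('hello')


def try_pass(q):
    """Pass q as a task argument and report what happens."""
    with Pool(1) as p:
        try:
            p.apply_async(put_one, (q,)).get(timeout=30)
            return 'ok'
        except RuntimeError as e:
            # Raised while the pool pickles the task: a multiprocessing.Queue
            # may only be shared with child processes through inheritance.
            return 'RuntimeError: {}'.format(e)


if __name__ == '__main__':
    print(try_pass(Queue()))
    with Manager() as manager:
        managed = manager.Queue()
        print(try_pass(managed), managed.get())
```

The plain queue fails when the task is pickled, whereas the managed queue is only a proxy to a queue living in the manager's server process, so it can be pickled and sent like any other argument.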

There are several things to note about this version. First, although we are using a pool size of 8, we can see that a single process, 10516, processed all the chunks. This was probably because check_prime is so trivial that the first find_prime_worker task to start was able to read a chunk from the input queue, call check_prime once for each of the 5 numbers in the chunk, and go back for the remaining chunks before the other processes in the pool even had a chance to run.

Second, creating a multiprocessing pool and sending arguments and results through queues add overhead that a plain serial loop does not have. find_prime_worker has to be sufficiently CPU-intensive for the gains from parallel processing to outweigh that overhead. I doubt that is the case here, at least not with this implementation of check_prime.
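A rough way to see this overhead for yourself is to time a serial version against a pool version. This sketch reuses the placeholder check_prime; find_primes_in_chunk, run_serial and run_pool are hypothetical helpers. The timings depend entirely on your machine and start method, so only the results, not the times, are meant to be compared exactly:

```python
import time
from multiprocessing import Pool


def check_prime(n):
    # Same placeholder as above: not a real primality test.
    return n in (3, 5, 7, 11, 13, 17, 19, 23, 97)


def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def find_primes_in_chunk(chunk):
    return list(filter(check_prime, chunk))


def run_serial(N, chunksize):
    return [n for chunk in chunks(range(1, N), chunksize)
            for n in find_primes_in_chunk(chunk)]


def run_pool(ncore, N, chunksize):
    # Pool start-up and pickling of arguments/results are included in the timing.
    with Pool(ncore) as p:
        results = p.map(find_primes_in_chunk, chunks(range(1, N), chunksize))
    return [n for found in results for n in found]


if __name__ == '__main__':
    for label, fn in (('serial', lambda: run_serial(100, 5)),
                      ('pool', lambda: run_pool(8, 100, 5))):
        start = time.perf_counter()
        primes = fn()
        print('{}: {:.4f} s -> {}'.format(label, time.perf_counter() - start, primes))
```

With work this trivial, the pool version is dominated by starting 8 processes; only when each chunk does substantial CPU work does the parallelism start to pay for itself.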

Third, and this is very subtle, using a multiprocessing.Queue requires a bit of care. In the above code we know that find_prime_worker returns nothing but None, so the statement return_values = [result.get() for result in results] was mostly for demo purposes and to make sure that any exception raised by find_prime_worker would not go unnoticed. In general, if you just want to wait for all submitted tasks to complete and do not care about the return values, then instead of saving all the AsyncResult instances returned by the apply_async calls and calling get on them, you can simply call p.close() followed by p.join(). Once these calls return, you can be sure that all submitted tasks have completed (possibly with an exception), because join waits for all pool processes to exit, which happens only after they have finished their work or hit some exceptional condition. I did not use that method in this code, however, because you must never join a process that has written to a multiprocessing queue before you have read all the messages it has written to that queue: such a process does not terminate until everything it has put on the queue has been flushed to the underlying pipe, so joining it too early can deadlock.

But we can change the logic so we know that we have read every possible message from the queue by having each worker task write a None record after it has finished processing all of its chunks:

def init_processes(inq, outq):
    global input, output
    input = inq
    output = outq

def check_prime(n):
    # Not a real implementation, and not all the primes between 1 and 100:
    if n in (3, 5, 7, 11, 13, 17, 19, 23, 97):
        return True
    return False

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def find_prime_worker():
    from os import getpid

    for chunk in iter(input.get, None):
        print('(Pid: {})'.format(getpid()))
        primes_found = list(filter(check_prime, chunk))
        if primes_found:
            output.put(primes_found)
    output.put(None)

def calculate_primes(ncore, N, chunksize):
    from multiprocessing import Pool, Queue

    input = Queue()
    output = Queue()

    with Pool(ncore, initializer=init_processes, initargs=(input, output)) as p:
        results = [p.apply_async(find_prime_worker) for _ in range(ncore)]
        for r in chunks(range(1, N), chunksize):
            input.put(r)
        for i in range(ncore):
            input.put(None)
        # We assume find_prime_worker does not raise an exception, so there
        # will be exactly ncore None records on the output queue:
        none_seen = 0
        res = []
        while none_seen < ncore:
            result = output.get()
            if result is None:
                none_seen += 1
            else:
                res.extend(result)
        # All output has been read, so it is now safe (and tidy) to close and join the pool:
        p.close()
        p.join()
    return res

# Required on Windows (and wherever processes are started with spawn):
if __name__ == '__main__':
    print(calculate_primes(8, 100, 5))

Prints:

(Pid: 15796)
(Pid: 14644)
(Pid: 19100)
(Pid: 15796)
(Pid: 15796)
(Pid: 14644)
(Pid: 15796)
(Pid: 19100)
(Pid: 14644)
(Pid: 14644)
(Pid: 19100)
(Pid: 15796)
(Pid: 14644)
(Pid: 8924)
(Pid: 15796)
(Pid: 19100)
(Pid: 14644)
(Pid: 8924)
(Pid: 15796)
(Pid: 19100)
[3, 5, 17, 19, 7, 23, 11, 13, 97]

Note that I added the additional check to find_prime_worker as there is no point in writing empty lists to the output queue:

        if primes_found:
            output.put(primes_found)
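The p.close() / p.join() idiom mentioned earlier can be checked in isolation. In this sketch slow_square and run are hypothetical, and the tasks return values only so there is something to check. No AsyncResult.get() is called before the join, yet afterwards every result reports ready(). This is safe here only because the tasks write nothing to a multiprocessing.Queue that the main process has not yet drained:

```python
import time
from multiprocessing import Pool


def slow_square(x):
    time.sleep(0.05)  # simulate some work
    return x * x


def run(ncore=4, ntasks=8):
    p = Pool(ncore)
    results = [p.apply_async(slow_square, (i,)) for i in range(ntasks)]
    p.close()  # no more tasks may be submitted
    p.join()   # wait for every pool process to finish its work and exit
    # Every submitted task has completed (successfully or not) by this point.
    return [r.ready() for r in results], [r.get() for r in results]


if __name__ == '__main__':
    ready, values = run()
    print(all(ready), values)
```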

And finally, the usual case for an input and an output queue is when you are not using a multiprocessing pool but are instead creating your own multiprocessing.Process instances, in effect using these Process instances and queues to implement your own pool:

def check_prime(n):
    # Not a real implementation, and not all the primes between 1 and 100:
    if n in (3, 5, 7, 11, 13, 17, 19, 23, 97):
        return True
    return False

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def find_prime_worker(input, output):
    from os import getpid

    for chunk in iter(input.get, None):
        print('(Pid: {})'.format(getpid()))
        primes_found = list(filter(check_prime, chunk))
        if primes_found:
            output.put(primes_found)
    output.put(None)

def calculate_primes(ncore, N, chunksize):
    from multiprocessing import Process, Queue

    input = Queue()
    output = Queue()

    processes = [Process(target=find_prime_worker, args=(input, output))
                 for _ in range(ncore)
                 ]
    for process in processes:
        process.start()
    for r in chunks(range(1, N), chunksize):
        input.put(r)
    for i in range(ncore):
        input.put(None)

    none_seen = 0
    res = []
    while none_seen < ncore:
        result = output.get()
        if result is None:
            none_seen += 1
        else:
            res.extend(result)

    for process in processes:
        process.join()

    return res

# Required on Windows (and wherever processes are started with spawn):
if __name__ == '__main__':
    print(calculate_primes(8, 100, 5))

Prints:

(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
(Pid: 18112)
[3, 5, 7, 11, 13, 17, 19, 23, 97]

Upvotes: 2
