Dariyoush

Reputation: 508

Why is serial code faster than concurrent.futures in this case?

I am using the following code to process some pictures for my ML project, and I would like to parallelize it.

import multiprocessing as mp
import concurrent.futures

def track_ids(seq):
    '''The func is so big I can not put it here'''
    ood = {}
    for i in seq:
        # I load around 500 images and process them
        ood[i] = some_value  # placeholder for the real per-image computation
    return ood

seqs = []
for seq in range(1, 10):  # len(seqs)+1
    seq = txt + str(seq)
    seqs.append(seq)
    # serial call of the function
    track_ids(seq)

#parallel call of the function
with concurrent.futures.ProcessPoolExecutor(max_workers=mp.cpu_count()) as ex:
    ood_id = ex.map(track_ids, seqs)

If I run the code serially it takes 3.0 minutes, but in parallel with concurrent.futures it takes 3.5 minutes. Can someone please explain why that is, and suggest a way to solve the problem?

Btw, I have 12 cores. Thanks.

Upvotes: 0

Views: 268

Answers (1)

Aaron

Reputation: 11075

Here's a brief example of how one might go about profiling multiprocessing code vs serial execution:

import multiprocessing as mp
from cProfile import Profile
from pstats import Stats
import concurrent.futures

def track_ids(seq):
    '''The func is so big I can not put it here'''
    ood = {}
    for i in seq:
        # I load around 500 images and process them
        ood[i] = some_value  # placeholder for the real per-image computation
    return ood

def profile_seq():
    p = Profile() #one and only profiler instance
    p.enable()
    seqs = []
    for seq in range(1, 10):  # len(seqs)+1
        seq = txt + str(seq)
        seqs.append(seq)
        # serial call of the function
        track_ids(seq)
    p.disable()
    return Stats(p), seqs


def track_ids_pr(seq):
    p = Profile() #profile the child tasks
    p.enable()
    
    retval = track_ids(seq)
    
    p.disable()
    return (Stats(p, stream="dummy"), retval)  # "dummy" keeps the Stats picklable; see the EDIT below
    
def profile_parallel(seqs):
    p = Profile() #profile stuff in the main process
    p.enable()
    
    with concurrent.futures.ProcessPoolExecutor(max_workers=mp.cpu_count()) as ex:
        retvals = ex.map(track_ids_pr, seqs)
        
    p.disable()
    s = Stats(p)
    
    out = []
    for ret in retvals:
        s.add(ret[0])
        out.append(ret[1])
        
    return s, out


if __name__ == "__main__":
    serial_stats, seqs = profile_seq()  # run and profile the serial version first
    serial_stats.print_stats()

    par_stats, retval = profile_parallel(seqs)  # then profile the parallel version
    par_stats.print_stats()

EDIT: Unfortunately I found out that pstats.Stats objects cannot be passed through a multiprocessing.Queue as-is, because they are not picklable (and pickling is also how concurrent.futures moves arguments and results between processes). A Stats object normally keeps a reference to the file it will write its statistics to, and if none is given it grabs a reference to sys.stdout by default; file objects cannot be pickled. We don't actually need that reference until we want to print the statistics, however, so we can give it a temporary picklable value to prevent the pickle error and restore an appropriate value later. The following example should be copy-paste-able and run just fine, unlike the pseudocode-ish example above.

from multiprocessing import Queue, Process
from cProfile import Profile
from pstats import Stats
import sys

def isprime(x):
    for d in range(2, int(x**.5) + 1):  # include the square root itself
        if x % d == 0:
            return False
    return True

def foo(retq):
    p = Profile()
    p.enable()
    
    primes = []
    max_n = 2**20
    for n in range(3, max_n):
        if isprime(n):
            primes.append(n)
        
    p.disable()
    retq.put(Stats(p, stream="dummy")) #Dirty hack: set `stream` to something picklable then override later

if __name__ == "__main__":
    q = Queue()
    
    p1 = Process(target=foo, args=(q,))
    p1.start()
    
    p2 = Process(target=foo, args=(q,))
    p2.start()
    
    s1 = q.get()
    s1.stream = sys.stdout #restore original file
    s2 = q.get()
    # s2.stream is left alone: when this Stats is added to another, the stream just gets thrown away anyway.
    
    s1.add(s2) #add up the stats from both child processes.
    s1.print_stats()  # s1.stream gets used here, but not before. If you provide a file to write to instead of sys.stdout, it will write to that file.
    
    p1.join()
    p2.join()
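
As an aside, here is a minimal sketch of the pickling problem itself (standard library only; the summing line is just an arbitrary stand-in workload): a fresh Stats object grabs sys.stdout as its stream, and file objects cannot be pickled, while any picklable placeholder lets it survive a round trip.

import pickle
import sys
from cProfile import Profile
from pstats import Stats

p = Profile()
p.enable()
sum(i * i for i in range(100_000))  # arbitrary small workload to profile
p.disable()

try:
    pickle.dumps(Stats(p))  # stream defaults to sys.stdout -> TypeError
except TypeError as e:
    print("pickling failed:", e)

s = pickle.loads(pickle.dumps(Stats(p, stream="dummy")))  # placeholder survives the round trip
s.stream = sys.stdout  # restore a real stream before printing
s.print_stats()

This is exactly the error you would hit sending Stats objects through the Queue (or returning them from concurrent.futures workers) without the placeholder.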

Upvotes: 1
