Reputation: 1012
I was trying to count word frequencies using multiprocessing with a shared dict. I wrote a simple Python code snippet as an initial test:
from multiprocessing import Manager, Pool

def foo(num):
    try:
        d[num] += 1
    except KeyError:
        d[num] = 1

d = Manager().dict()
pool = Pool(processes=2, maxtasksperchild=100)
tasks = [1] * 1001 + [2] * 2000 + [3] * 1300
pool.map_async(foo, tasks)
pool.close()
pool.join()
print(len(tasks))
print(d)
However, the counts in d do not add up to the number of items in tasks. It seems to me that d is not properly synchronized, but I have no clue why that happens or how to fix it. Can someone help me out here?
Upvotes: 1
Views: 206
Reputation: 365607
You've got a race condition here:
try:
    d[num] += 1
except KeyError:
    d[num] = 1
Let's say that task 1 tries to do d[1] += 1, but d[1] hasn't been set yet, so it gets a KeyError. Now task 2, on the other core, tries to do d[1] += 1, but d[1] still hasn't been set, so it also gets a KeyError. So now both task 1 and task 2 will try to set d[1] = 1, and they'll both succeed, so d[1] is now 1, and you've lost 1 increment.
Even worse, let's say that before task 1 gets around to setting d[1] = 1, tasks 3-10 all run on the other core and increment d[1] all the way up to 9. Then task 1 comes in and sets it back to 1, and you've lost 9 increments.
You might think you could solve this by preinitializing d = {1: 0, 2: 0, 3: 0} and leaving out the try/except. But that still won't work, because even d[1] += 1 isn't atomic. Python compiles it into, effectively, three separate operations: tmp = d.__getitem__(1), then tmp += 1, then d.__setitem__(1, tmp).

So, task 1 could fetch the existing 0 from the shared dict and increment it to 1, while task 2 has meanwhile fetched the same 0 and incremented it to 1; now they both go to store 1, and both succeed. And, again, you can see how this extends to losing large batches of increments rather than just one.
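You can see those three steps for yourself with the dis module (a quick demonstration added here, not part of the original code; the exact opcode names vary between Python versions):

import dis

def increment(d, num):
    d[num] += 1

# d[num] += 1 compiles to separate load / add / store instructions; with a
# Manager dict, the load and the store are separate round-trips to the
# manager process, so another worker's update can land in between them.
dis.dis(increment)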
Any non-atomic operations on shared data have to be explicitly synchronized. This is explained in Synchronization between processes and Sharing state between processes in the docs.
The simplest fix here (although obviously not the best, because it ends up serializing all of your access) is:
from multiprocessing import Manager, Pool, Lock

lock = Lock()

def foo(num):
    with lock:
        try:
            d[num] += 1
        except KeyError:
            d[num] = 1
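Putting it together, a complete runnable version might look like this (one caveat: wiring the lock and the managed dict through the Pool initializer is an addition here, so that they also reach the workers on platforms that spawn processes instead of forking):

from multiprocessing import Manager, Pool, Lock

def init_worker(l, shared):
    # Stash the lock and the managed dict as globals in each worker process.
    global lock, d
    lock = l
    d = shared

def foo(num):
    with lock:  # serialize the whole read-modify-write so no increment is lost
        try:
            d[num] += 1
        except KeyError:
            d[num] = 1

if __name__ == '__main__':
    lock = Lock()
    d = Manager().dict()
    pool = Pool(processes=2, initializer=init_worker, initargs=(lock, d))
    tasks = [1] * 1001 + [2] * 2000 + [3] * 1300
    pool.map(foo, tasks)
    pool.close()
    pool.join()
    print(len(tasks))       # 4301
    print(sum(d.values()))  # should now also be 4301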
If you want to get fancy and make this more efficient, you're going to have to learn about how shared memory and synchronization work; there's way too much to explain in a single StackOverflow answer.
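That said, for a counting job like this one, a common pattern (sketched below as an alternative, not something from the code above) avoids shared state entirely: each worker counts its own chunk locally with a Counter, and the parent merges the partial results, so no lock is needed at all:

from collections import Counter
from multiprocessing import Pool

def count_chunk(chunk):
    # Purely local work: no shared dict, no lock.
    return Counter(chunk)

if __name__ == '__main__':
    tasks = [1] * 1001 + [2] * 2000 + [3] * 1300
    chunks = [tasks[i::2] for i in range(2)]  # one slice per worker
    with Pool(processes=2) as pool:
        partials = pool.map(count_chunk, chunks)
    totals = sum(partials, Counter())  # merge the per-worker counts
    print(totals)  # Counter({2: 2000, 3: 1300, 1: 1001})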
Upvotes: 2