Reputation: 159
I have a function that does calculations with the help of a mutable object as in the example below:
def fun(obj: MutableObject, input_a, input_b):
    obj.a = input_a
    return obj.do_stuff(input_b)
I need to do this many times and am currently using a for loop as seen below:
obj = MutableObject()
output = []
for input_a, input_b in inputs:
    output.append(fun(obj, input_a, input_b))
To speed this up I want to use Python multiprocessing and perform multiple calls of fun in parallel. A common way I have seen this done is to use multiprocessing.Pool to map over a list of inputs. The problem with such an implementation in my case is the mutable object that needs to be shared between the processes: I would like each process to have access to its own clone of the object, without creating unnecessarily many clones.
A naive attempt would be to copy the object for each input:
import multiprocessing
import copy

obj = MutableObject()

def map_fun(arg):
    input_a, input_b = arg
    temp_obj = copy.deepcopy(obj)
    return fun(temp_obj, input_a, input_b)

pool = multiprocessing.Pool()
outputs = pool.map(map_fun, inputs)
But that seems wasteful, both on CPU and memory.
Is there any way I can create a temporary pool of object copies, one for each parallel process, instead of creating one for each input pair?
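To make it concrete, something like this rough sketch is what I have in mind (init_worker and _worker_obj are names I made up for illustration; fun, MutableObject and inputs are as above), where the copy happens once per worker process rather than once per input:

import copy
import multiprocessing

_worker_obj = None

def init_worker(template_obj):
    # Runs once in each worker process, so there is exactly one clone per process.
    global _worker_obj
    _worker_obj = copy.deepcopy(template_obj)

def map_fun(arg):
    input_a, input_b = arg
    return fun(_worker_obj, input_a, input_b)

if __name__ == "__main__":
    obj = MutableObject()
    with multiprocessing.Pool(initializer=init_worker, initargs=(obj,)) as pool:
        outputs = pool.map(map_fun, inputs)

I don't know whether that is the idiomatic way to do it, or whether there is something built in for this.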
EDIT:
It was pointed out in a comment that memory probably won't be an issue, as garbage collection will clean up the unused copies. I am still worried that the copying will require a lot of resources, though, as my MutableObject is in reality a Keras model (neural network) that can be quite large.
Upvotes: 1
Views: 184
Reputation: 5954
Here is a solution which drops the pool and manages the worker processes itself, ensuring there is only one object per process:
from multiprocessing import Process, cpu_count, JoinableQueue


class MuteableObj:
    def method(self, data):
        data["processed"] = True
        return data


class Worker(Process):
    def __init__(self, task_queue, result_queue):
        super().__init__()
        print("Started", self.name)
        self.task_queue = task_queue
        self.result_queue = result_queue
        # One object per worker process, created once at initialisation.
        self._obj = MuteableObj()
        self._open = True

    def run(self):
        while self._open:
            task = self.task_queue.get()
            print(f"Processing {task['id']}")
            result = self._obj.method(task)
            self.task_queue.task_done()
            self.result_queue.put(result)
        print("over")

    def terminate(self):
        print("Stopped", self.name)
        super().terminate()


task_queue = JoinableQueue()
result_queue = JoinableQueue()

NTHREADS = cpu_count()

for i in range(200):
    task_queue.put(dict(id=i))

threads = [Worker(task_queue, result_queue) for i in range(NTHREADS)]

for t in threads:
    t.start()

# join() blocks until task_done() has been called for every queued task.
task_queue.join()

# The workers are now blocked waiting for more tasks, so stop them explicitly.
for t in threads:
    t.terminate()

results = []
while not result_queue.empty():
    results.append(result_queue.get())

print(results)
Firstly we have a mock of your mutable object, here just a class with one method we care about.
We subclass Process ourselves and give each process one object at initialisation. Then we fill a JoinableQueue with the required tasks and wait until they are all done, at which point we get all the results out of another queue (we could instead use a list and a Lock, but I think this is easier to read).
Note that the results are not guaranteed to come back in the order in which the tasks were sent. If this matters, give each task an id as I have here, and sort the results by that id.
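For example, since each result dict carries its id, the original order can be restored with an ordinary list sort (nothing multiprocessing-specific here):

results.sort(key=lambda r: r["id"])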
If you need the pool to run indefinitely and do specific things with each result, you probably want to write a callback, move the join() to the end of the code (since it blocks until all tasks are processed), and then have a loop which waits for results and calls your callback:
from time import sleep

while running:
    while not result_queue.empty():
        callback(result_queue.get())
    while result_queue.empty():
        sleep(0.1)
In this case I would wrap all this up in another class, called something like TaskRunner, to keep state (like running) isolated.
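A rough sketch of what that wrapper might look like (TaskRunner and its method names are just my suggestion; it reuses Worker, JoinableQueue, cpu_count and sleep from above, and expects the callback, or some other code, to call stop() to end the loop):

class TaskRunner:
    """Owns the queues, the worker processes and the running flag."""

    def __init__(self, n_workers=None):
        self.task_queue = JoinableQueue()
        self.result_queue = JoinableQueue()
        self.workers = [Worker(self.task_queue, self.result_queue)
                        for _ in range(n_workers or cpu_count())]
        self.running = False

    def start(self):
        self.running = True
        for w in self.workers:
            w.start()

    def submit(self, task):
        self.task_queue.put(task)

    def run(self, callback):
        # Hand each result to the callback until stop() flips self.running.
        while self.running:
            while not self.result_queue.empty():
                callback(self.result_queue.get())
            sleep(0.1)

    def stop(self):
        self.running = False
        for w in self.workers:
            w.terminate()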
Incidentally I first came across this recipe on SO years ago, and I've been using it ever since.
Upvotes: 1