Reputation: 5599
I have several calls to tf.gradients that each take some time, so I would like to call tf.gradients concurrently. However, I receive one of several errors when I try to do so in my graph. I suspect it is not thread-safe, but I have not been able to reproduce the error with an MWE. I tried using both pathos.pools.ThreadPool and pathos.pools.ProcessPool in both my MWE and my real code; only my real code fails. Here is the MWE that I tried:
from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random 10x10 constants to differentiate with respect to.
Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0] * Xs[1] * Xs[2], Xs[0] / Xs[1] * Xs[2], Xs[0] / Xs[1] / Xs[2]]

def compute_grad(YX):
    # YX is a (Y, X) pair; differentiate Y with respect to X.
    return tf.gradients(YX[0], YX[1])

tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)
Here's a partial traceback I encountered when running my real code. This is the ThreadPool version:
File "pathos/threading.py", line 134, in map
return _pool.map(star(f), zip(*args)) # chunksize
File "multiprocess/pool.py", line 260, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "multiprocess/pool.py", line 608, in get
raise self._value
File "multiprocess/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "multiprocess/pool.py", line 44, in mapstar
return list(map(*args))
File "pathos/helpers/mp_helper.py", line 15, in <lambda>
func = lambda args: f(*args)
File "my_code.py", line 939, in gradients_with_index
return (tf.gradients(Y, variables), b_idx)
File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
colocate_gradients_with_ops)
File "tensorflow/python/ops/gradients_impl.py", line 188, in _PendingCount
between_op_list, between_ops, colocate_gradients_with_ops)
File "tensorflow/python/ops/control_flow_ops.py", line 1288, in MaybeCreateControlFlowState
loop_state.AddWhileContext(op, between_op_list, between_ops)
File "tensorflow/python/ops/control_flow_ops.py", line 1103, in AddWhileContext
grad_state = GradLoopState(forward_ctxt, outer_grad_state)
File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
cnt, outer_grad_state)
File "tensorflow/python/ops/control_flow_ops.py", line 2282, in AddBackPropLoopCounter
merge_count = merge([enter_count, enter_count])[0]
File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
return gen_control_flow_ops._merge(inputs, name)
File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
result = _op_def_lib.apply_op("Merge", inputs=inputs, name=name)
File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "tensorflow/python/framework/ops.py", line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File "tensorflow/python/framework/ops.py", line 1273, in __init__
self._control_flow_context.AddOp(self)
File "tensorflow/python/ops/control_flow_ops.py", line 2147, in AddOp
self._AddOpInternal(op)
File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _AddOpInternal
self._MaybeAddControlDependency(op)
File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _MaybeAddControlDependency
op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'
Here is another traceback. Note that the error is different:
Traceback (most recent call last):
File "tensorflow/python/ops/control_flow_ops.py", line 869, in AddForwardAccumulator
enter_acc = self.forward_context.AddValue(acc)
File "tensorflow/python/ops/control_flow_ops.py", line 2115, in AddValue
self._outer_context.AddInnerOp(enter.op)
File "tensorflow/python/framework/ops.py", line 3355, in __exit__
self._graph._pop_control_dependencies_controller(self)
File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
assert self._control_dependencies_stack[-1] is controller
AssertionError
The ProcessPool version encountered the error:
_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper
Upvotes: 1
Views: 296
Reputation: 126184
The tf.gradients() function is not thread-safe. It makes a sequence of complicated and non-atomic modifications to your graph, and these are not protected by locks. In particular, it seems that using tf.gradients() on a graph that contains control flow operations (such as tf.while_loop()) is more likely to run into problems if you run it concurrently.
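If you still want to drive the calls from a thread pool, a minimal workaround sketch (not part of the original answer; it reuses the Ys, Xs, and ThreadPool from the question's MWE) is to guard graph construction with a shared threading.Lock, which keeps the pool structure but effectively serializes the tf.gradients() calls:

import threading
from pathos.pools import ThreadPool
import tensorflow as tf

# One lock shared by every worker; all graph mutation happens inside it.
graph_lock = threading.Lock()

def compute_grad_locked(YX):
    # tf.gradients() makes non-atomic graph modifications, so hold the
    # lock for the entire call.
    with graph_lock:
        return tf.gradients(YX[0], YX[1])

tp = ThreadPool(3)
res = tp.map(compute_grad_locked, zip(Ys, Xs))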
Note that it is unlikely that issuing parallel calls to tf.gradients() would speed it up, even if it were implemented in a thread-safe manner. The function performs no I/O and does not call any native methods that release Python's GIL, so the execution would most likely be serialized. Implementing multiprocessing-based parallelism would require additional system calls for accessing the shared graph (and acquiring/releasing locks), so it is unlikely that this would be faster.
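If the goal is to overlap the execution of the gradient computations rather than their construction, a sketch under the same TF 1.x assumptions (again reusing Ys and Xs from the MWE above) is to build the gradient ops sequentially and then fetch them all in a single Session.run() call; TensorFlow's runtime can execute independent subgraphs in parallel on its internal thread pool:

import tensorflow as tf

# Build each gradient op one at a time; graph construction stays
# single-threaded and therefore safe.
grad_ops = [tf.gradients(Y, X) for Y, X in zip(Ys, Xs)]

with tf.Session() as sess:
    # One run() call; independent gradient subgraphs may execute
    # concurrently inside the runtime.
    all_grads = sess.run(grad_ops)
print(all_grads)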
Upvotes: 3