eqzx

Reputation: 5599

Is tf.gradients thread-safe?

I have several calls to tf.gradients that each take some time, so I would like to run them concurrently. However, when I do so in my real graph, I get one of several errors. I suspect tf.gradients is not thread-safe, but I have not been able to reproduce the errors with an MWE. I tried both pathos.pools.ThreadPool and pathos.pools.ProcessPool in both the MWE and my real code; only the real code fails. Here is the MWE I tried:

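from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random 10x10 float64 tensors and three expressions built from them.
Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for _ in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

def compute_grad(YX):
    # YX is a (Y, X) pair; build the gradient ops of Y with respect to X.
    return tf.gradients(YX[0], YX[1])

# Call tf.gradients from three threads at once
# (swap in ProcessPool for the process-based version).
tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)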

Here's a partial traceback I encountered when trying my real code. This is the ThreadPool version.

File "pathos/threading.py", line 134, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "multiprocess/pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "multiprocess/pool.py", line 608, in get
    raise self._value
  File "multiprocess/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "multiprocess/pool.py", line 44, in mapstar
    return list(map(*args))
  File "pathos/helpers/mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "my_code.py", line 939, in gradients_with_index
    return (tf.gradients(Y, variables), b_idx)
  File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
    colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients_impl.py", line 188, in _PendingCount
    between_op_list, between_ops, colocate_gradients_with_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1288, in MaybeCreateControlFlowState
    loop_state.AddWhileContext(op, between_op_list, between_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1103, in AddWhileContext
    grad_state = GradLoopState(forward_ctxt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
    cnt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 2282, in AddBackPropLoopCounter
    merge_count = merge([enter_count, enter_count])[0]
  File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
    return gen_control_flow_ops._merge(inputs, name)
  File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
    result = _op_def_lib.apply_op("Merge", inputs=inputs, name=name)
  File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 1273, in __init__
    self._control_flow_context.AddOp(self)
  File "tensorflow/python/ops/control_flow_ops.py", line 2147, in AddOp
    self._AddOpInternal(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _AddOpInternal
    self._MaybeAddControlDependency(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _MaybeAddControlDependency
    op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'

Here is another traceback. Note that the error is different:

Traceback (most recent call last):
  File "tensorflow/python/ops/control_flow_ops.py", line 869, in AddForwardAccumulator
    enter_acc = self.forward_context.AddValue(acc)
  File "tensorflow/python/ops/control_flow_ops.py", line 2115, in AddValue
    self._outer_context.AddInnerOp(enter.op)
  File "tensorflow/python/framework/ops.py", line 3355, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

The ProcessPool version encountered the error:

_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper

Upvotes: 1

Views: 296

Answers (1)

mrry

Reputation: 126184

The tf.gradients() function is not thread-safe. It makes a sequence of complicated and non-atomic modifications to your graph, and these are not protected by locks. In particular, it seems that using tf.gradients() on a graph that contains control flow operations (such as tf.while_loop()) is more likely to run into problems if you run it concurrently.
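
If the calls have to come from several threads anyway, one way to keep them from interleaving is to put every graph-building call behind a single lock. The following is a minimal sketch of that idea, not TensorFlow code: the names `_graph_build_lock`, `serialized`, and `fake_build` are hypothetical, and `fake_build` is only a stand-in for a non-thread-safe call such as tf.gradients so that the example runs without TensorFlow installed. It buys safety, not speed, since the guarded calls run one at a time.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# One process-wide lock guarding all graph construction (hypothetical name).
_graph_build_lock = threading.Lock()

def serialized(fn):
    """Wrap fn so that only one thread at a time can be inside it."""
    def wrapper(*args, **kwargs):
        with _graph_build_lock:
            return fn(*args, **kwargs)
    return wrapper

# Stand-in for a non-thread-safe graph-building call (assumption: in real
# code this would be the function that calls tf.gradients).
# It records how many threads are inside it at the same time.
_active = 0
max_active = 0

def fake_build(x):
    global _active, max_active
    _active += 1
    max_active = max(max_active, _active)
    time.sleep(0.01)  # widen the window in which another thread could enter
    _active -= 1
    return x * 2

safe_build = serialized(fake_build)

# Three worker threads submit six jobs; the lock lets them in one at a time.
with ThreadPoolExecutor(max_workers=3) as pool:
    results = list(pool.map(safe_build, range(6)))

print(results, "max threads inside at once:", max_active)
```

In real code the wrapped function would be whatever builds the gradient ops, and every other piece of code that modifies the same graph would need to take the same lock.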

Note that it is unlikely that issuing parallel calls to tf.gradients() would speed it up—even if it were implemented in a thread-safe manner. The function performs no I/O and does not call any native methods that release Python's GIL, so the execution would most likely be serialized. Implementing multiprocessing-based parallelism would require additional system calls for accessing the shared graph (and acquiring/releasing locks), so it is unlikely that this would be faster.
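
The GIL point can be checked with a rough experiment that has nothing to do with TensorFlow. The sketch below assumes a standard GIL-enabled CPython build and uses a hypothetical pure-Python function `busy` as a stand-in for CPU-bound work that never releases the GIL. It runs the same jobs sequentially and on four threads and prints both timings; on such a build the two are typically close, though exact numbers vary by machine.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def busy(n):
    """Pure-Python CPU-bound loop; it holds the GIL while it runs."""
    total = 0
    for i in range(n):
        total += i * i
    return total

def timed(label, run):
    """Run a zero-argument callable, print its wall-clock time, return its result."""
    start = time.perf_counter()
    out = run()
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.3f}s")
    return out

jobs = [200_000] * 4  # four identical CPU-bound jobs

# Same work, first one job after another, then spread over four threads.
sequential = timed("sequential", lambda: [busy(n) for n in jobs])
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = timed("4 threads", lambda: list(pool.map(busy, jobs)))
```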

Upvotes: 3
