Reputation: 427
I currently have the following code which I would like to use to provide a "stream" of incrementing integers.
import tensorflow as tf
...
record_count = tf.user_ops.my_custom_op(...)  # something I write in C++ or Python
...
my_variable = tf.Variable(0, dtype=tf.int64)
my_var_incremented = my_variable.assign_add(tf.to_int64(record_count))
queued_increment = tf.train.batch((my_variable,), 1)
But the issue is that queued_increment is just a reference to my_variable, when I just want to enqueue a copy of my_variable after incrementing it.
Is this the correct way to go about this, or am I missing something?
Upvotes: 0
Views: 149
Reputation: 126154
The current TensorFlow variables have unfortunate semantics when interacting with other stateful constructs (such as queues). The problem stems from "reference types" (note that my_variable.dtype is tf.int64_ref, which means that it is a mutable tensor reference), which most operations (including queues) implicitly "dereference" by creating a "constant" tensor that aliases the mutable buffer. We are in the process of fixing this bug in TensorFlow's memory model for variables, but the change is not in the public API yet.
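TensorFlow aside, the reference-vs-copy pitfall described above can be illustrated with a plain-Python sketch (the names buffer, by_reference, and by_copy are illustrative, not TensorFlow API):

```python
# Plain-Python analogy (not TensorFlow): storing a reference to a mutable
# buffer lets later mutations leak into what you thought you had saved,
# while an explicit copy freezes the value at "enqueue" time.
buffer = [0]              # stands in for the variable's mutable buffer

by_reference = buffer     # "enqueue" the reference (aliases the buffer)
by_copy = list(buffer)    # "enqueue" an explicit copy

buffer[0] += 1            # later increment, like assign_add

print(by_reference[0])    # 1 -- the alias sees the later increment
print(by_copy[0])         # 0 -- the copy kept the value at copy time
```

This is the same effect the asker observes: the queue holds an alias of the variable's buffer, so a dequeued element reflects increments made after it was enqueued.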
In the meantime, your best option is to force a copy when inserting the variable into the queue. The easiest solution relies on undocumented behavior: tf.QueueBase.enqueue_many() will always copy its values into the queue, even when you enqueue a single element. When using it via tf.train.batch(), you just need to reshape the variable (e.g. using tf.expand_dims()) and pass enqueue_many=True. For example:
my_variable = tf.Variable(0, dtype=tf.int64)
# ...
queued_increment = tf.train.batch((tf.expand_dims(my_variable, 0),), 1,
                                  enqueue_many=True)
Upvotes: 1