Sam

Incrementing variable as a copy in Tensorflow

I currently have the following code which I would like to use to provide a "stream" of incrementing integers.

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
...
record_count = tf.user_ops.my_custom_op(...)  # something I write in C++ or Python
...
my_variable = tf.Variable(0, dtype=dtypes.int64)
my_var_incremented = my_variable.assign_add(math_ops.to_int64(record_count))
queued_increment = tf.train.batch((my_variable,), 1)

The problem is that queued_increment only holds a reference to my_variable, whereas I want to enqueue a copy of my_variable's value after incrementing it.

Is this the correct way to go about this, or am I missing something?

Answers (1)

mrry

Current TensorFlow variables have unfortunate semantics when they interact with other stateful constructs, such as queues. The problem stems from "reference types": note that my_variable.dtype is tf.int64_ref, which means that it is a mutable tensor reference. Most operations, including queues, implicitly "dereference" such a value by creating a "constant" tensor that aliases the mutable buffer instead of copying it, so later updates to the variable show through. We are in the process of fixing this bug in TensorFlow's memory model for variables, but the change is not in the public API yet.
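
To make the aliasing concrete, here is a plain-Python analogy. It is not TensorFlow code and not how TensorFlow is implemented internally: a one-element list stands in for the variable's mutable buffer, and queue.Queue stands in for the TensorFlow queue. Enqueuing the list itself loosely mirrors the implicit dereference described above (every queued element sees the latest value), while enqueuing list(buffer) mirrors forcing a copy.

```python
import queue

# Analogy only: a one-element list plays the role of the variable's buffer.
buffer = [0]


def increment(buf, amount):
    # Mutates the shared buffer in place, loosely like assign_add.
    buf[0] += amount


aliased_q = queue.Queue()
copied_q = queue.Queue()

for amount in (3, 4, 5):
    increment(buffer, amount)
    aliased_q.put(buffer)        # enqueues a reference to the same list
    copied_q.put(list(buffer))   # enqueues a snapshot of the current value

aliased = [aliased_q.get()[0] for _ in range(3)]
copied = [copied_q.get()[0] for _ in range(3)]

print(aliased)  # [12, 12, 12]
print(copied)   # [3, 7, 12]
```

Every aliased entry reports the final value, which is the symptom described in the question; only the copied entries preserve the value at enqueue time.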

In the meantime, your best option is to force a copy when inserting the variable into the queue. The easiest solution relies on undocumented behavior: tf.QueueBase.enqueue_many() always copies its values into the queue, even when you enqueue a single element. When using it via tf.train.batch(), you just need to reshape the variable into a batch of one element (e.g. using tf.expand_dims()) and pass enqueue_many=True. For example:

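my_variable = tf.Variable(0, dtype=tf.int64)
# ...
# Reshape the scalar to shape [1] (a batch of one element) so that
# enqueue_many=True copies its current value into the queue.
queued_increment = tf.train.batch((tf.expand_dims(my_variable, 0),), 1,
                                  enqueue_many=True)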
