I am using TensorFlow v1.15. I have a very basic implementation of a TensorArray, given in the following example:
import tensorflow as tf
an_array = tf.TensorArray(dtype=tf.float32, size=5, dynamic_size=True,
                          clear_after_read=False, element_shape=(16, 7, 2))
for i in range(5):
    val = tf.random.normal(shape=(16, 7, 2))
    an_array.write(i, val)
    print(tf.Session().run(val))
tensors = [an_array.read(j) for j in range(5)]
print(tf.Session().run(tensors))
The print inside the for loop prints random, non-zero values, while the last print statement prints all zeros. Why is this happening? Thanks.
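tf.TensorArray.write returns a new tf.TensorArray in which the write has taken place, rather than updating the object it is called on. In graph mode (TF 1.x), the write op only runs if something depends on the returned array. When that return value is discarded, the later read calls on the original array do not depend on any of the writes, so they see unwritten slots and, because element_shape is fully defined, return zeros. In general, the output of this function should replace the previous reference to the array: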
import tensorflow as tf
an_array = tf.TensorArray(dtype=tf.float32, size=5, dynamic_size=True,
                          clear_after_read=False, element_shape=(16, 7, 2))
for i in range(5):
    val = tf.random.normal(shape=(16, 7, 2))
    # Replace tensor array reference with the written one
    an_array = an_array.write(i, val)
    print(tf.Session().run(val))
tensors = [an_array.read(j) for j in range(5)]
print(tf.Session().run(tensors))
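The effect of discarding the return value can be reproduced without TensorFlow. The sketch below is a pure-Python analogy, not TensorFlow's implementation: FunctionalArray is a hypothetical class whose write returns a new object instead of mutating the existing one, and whose read of an unwritten slot returns 0.0. TensorFlow reaches the same result by a different route (the write op simply never runs if nothing uses the returned array), but the observable behaviour matches the question.

```python
from typing import Optional, Tuple


class FunctionalArray:
    """Hypothetical toy stand-in for graph-mode tf.TensorArray (an analogy only).

    write() never mutates self: it returns a new FunctionalArray that
    contains the written value. read() of a slot that was never written
    returns `default`, similar to a TensorArray with a fully defined
    element_shape returning zeros for unwritten indices.
    """

    def __init__(self, size: int, default: float = 0.0,
                 data: Optional[Tuple[Optional[float], ...]] = None):
        self.size = size
        self.default = default
        # None marks a slot that has not been written yet.
        self._data = data if data is not None else (None,) * size

    def write(self, index: int, value: float) -> "FunctionalArray":
        # Copy the slots, update one of them, and hand back a fresh object.
        new_data = list(self._data)
        new_data[index] = value
        return FunctionalArray(self.size, self.default, tuple(new_data))

    def read(self, index: int) -> float:
        value = self._data[index]
        return self.default if value is None else value


values = [1.5, -0.3, 2.0, 0.7, -1.1]

# Question's pattern: the array returned by write() is thrown away.
lost = FunctionalArray(size=5)
for i, v in enumerate(values):
    lost.write(i, v)  # `lost` itself is unchanged
print([lost.read(j) for j in range(5)])  # [0.0, 0.0, 0.0, 0.0, 0.0]

# Answer's pattern: rebind the name to the array returned by write().
kept = FunctionalArray(size=5)
for i, v in enumerate(values):
    kept = kept.write(i, v)
print([kept.read(j) for j in range(5)])  # [1.5, -0.3, 2.0, 0.7, -1.1]
```

Rebinding the name on every iteration is the entire fix, the same one-line change made in the TensorFlow code above.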