Reputation: 51
Given a TensorArray with a fixed size and entries of uniform shape, I want to convert it to a Tensor containing the same values, with the TensorArray's index dimension becoming a regular (leading) axis.
TensorArrays have a method called "gather" which should do exactly that. And, in fact, the following example works in eager mode:
import tensorflow as tf

array = tf.TensorArray(tf.int32, size=3)
array.write(0, 10)
array.write(1, 20)
array.write(2, 30)
gathered = array.gather([0, 1, 2])
"gathered" then yields the desired Tensor:
tf.Tensor([10 20 30], shape=(3,), dtype=int32)
Unfortunately, this stops working when wrapping it inside a tf.function, like so:
@tf.function
def func():
    array = tf.TensorArray(tf.int32, size=3)
    array.write(0, 10)
    array.write(1, 20)
    array.write(2, 30)
    gathered = array.gather([0, 1, 2])
    return gathered
tensor = func()
"tensor" then wrongly yields the following Tensor:
tf.Tensor([0 0 0], shape=(3,), dtype=int32)
Do you have an explanation for this, or can you suggest an alternative way to go from a TensorArray to a Tensor inside a tf.function?
Upvotes: 3
Views: 2110
Reputation: 11
Instead of array.gather(), try using array.stack(); it returns a Tensor built from the elements of the TensorArray.
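A minimal sketch of the stack() approach applied to the question's function. One assumption worth flagging: stack() on its own does not address the missing reassignment, so inside tf.function it would likely still read the unwritten array. The sketch therefore also chains each write() result back into the variable.

```python
import tensorflow as tf

@tf.function
def func():
    array = tf.TensorArray(tf.int32, size=3)
    # Reassign on every write: in graph mode, write() returns a new
    # TensorArray that carries the written value.
    array = array.write(0, 10)
    array = array.write(1, 20)
    array = array.write(2, 30)
    # stack() packs all elements along a new leading axis, which is
    # equivalent to gathering every index in order.
    return array.stack()

tensor = func()
print(tensor)  # tf.Tensor([10 20 30], shape=(3,), dtype=int32)
```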
Upvotes: 0
Reputation: 90023
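Per https://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873 you have to replace arr.write(j, t) with arr = arr.write(j, t).
The issue is that tf.function executes as a graph. In eager mode the array is updated in place (as a convenience), but you are really meant to use the return value to chain operations: https://www.tensorflow.org/api_docs/python/tf/TensorArray#returns_6
In graph mode, write() does not modify the original TensorArray; it returns a new one. Because the question's code discards those return values, gather() reads the untouched array and returns zeros.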
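The same rule matters most when a TensorArray is filled in a loop, which is the usual reason to use one inside tf.function. Below is a sketch assuming TF 2.x with AutoGraph; the helper squares is hypothetical and only illustrates the pattern of reassigning ta on each iteration so the loop threads the updated array through.

```python
import tensorflow as tf

@tf.function
def squares(n):
    # Hypothetical helper: fills a TensorArray in a loop, then converts it.
    ta = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):
        # Chain the result; AutoGraph converts this loop into a
        # tf.while_loop and carries `ta` as a loop variable.
        ta = ta.write(i, i * i)
    # gather() over every index gives the same result as stack().
    return ta.gather(tf.range(n))

result = squares(tf.constant(4))
print(result)  # tf.Tensor([0 1 4 9], shape=(4,), dtype=int32)
```

Dropping the "ta =" inside the loop would reproduce the question's behaviour, since each write's result would be thrown away.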
Upvotes: 2