Daniel S.

Reputation: 51

Going from a TensorArray to a Tensor

Given a TensorArray with a fixed size and entries of uniform shape, I want to obtain a Tensor containing the same values, with the TensorArray's index dimension simply becoming a regular axis.

TensorArray has a method called "gather" that is supposed to do exactly that. And, in fact, the following example works:

import tensorflow as tf

array = tf.TensorArray(tf.int32, size=3)
array.write(0, 10)
array.write(1, 20)
array.write(2, 30)

gathered = array.gather([0, 1, 2])

"gathered" then yields the desired Tensor:

tf.Tensor([10 20 30], shape=(3,), dtype=int32)

Unfortunately, this stops working when wrapping it inside a tf.function, like so:

@tf.function
def func():
    array = tf.TensorArray(tf.int32, size=3)
    array.write(0, 10)
    array.write(1, 20)
    array.write(2, 30)

    gathered = array.gather([0, 1, 2])
    return gathered

tensor = func()

"tensor" then wrongly yields the following Tensor:

tf.Tensor([0 0 0], shape=(3,), dtype=int32)

Do you have an explanation for this, or can you suggest an alternative way to go from a TensorArray to a Tensor inside a tf.function?

Upvotes: 3

Views: 2110

Answers (2)

FQ912

Reputation: 11

Instead of array.gather(), try using array.stack(); it returns a Tensor from the TensorArray.
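A minimal sketch of `stack()` next to `gather()` inside a `tf.function`, assuming TensorFlow 2.x; `to_tensor` is a hypothetical name. The writes are chained through the return value of `write()`, as described in the other answer, so that the values are actually recorded in the graph. `stack()` returns every element in index order, while `gather()` can pick a subset of indices in any order:

```python
import tensorflow as tf

@tf.function
def to_tensor():
    array = tf.TensorArray(tf.int32, size=3)
    # Chain the writes: each write() returns the updated TensorArray.
    array = array.write(0, 10).write(1, 20).write(2, 30)
    # stack(): all elements in index order, packed along a new axis 0.
    stacked = array.stack()
    # gather(): only the requested indices, in the requested order.
    picked = array.gather([2, 0])
    return stacked, picked

stacked, picked = to_tensor()
print(stacked)  # values [10 20 30]
print(picked)   # values [30 10]
```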

Upvotes: 0

Gili

Reputation: 90023

Per https://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873 you have to:

Replace arr.write(j, t) with arr = arr.write(j, t)
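Applied to the function from the question, the fix is a sketch like the following (assuming TensorFlow 2.x); only the `write()` lines change:

```python
import tensorflow as tf

@tf.function
def func():
    array = tf.TensorArray(tf.int32, size=3)
    # Inside a graph, write() does not update `array` in place; keep the
    # TensorArray it returns so the later gather() sees the written values.
    array = array.write(0, 10)
    array = array.write(1, 20)
    array = array.write(2, 30)

    gathered = array.gather([0, 1, 2])
    return gathered

tensor = func()
print(tensor)  # values [10 20 30] instead of [0 0 0]
```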

The issue is that tf.function executes as a graph. In eager mode the array is updated in place (as a convenience), but you are really meant to use the return value of write() to chain operations: https://www.tensorflow.org/api_docs/python/tf/TensorArray#returns_6
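The same rule matters most in loops, where the TensorArray is carried from one iteration to the next. A minimal sketch, assuming TensorFlow 2.x with AutoGraph converting the `tf.range` loop into a graph loop; `squares` is a hypothetical example function, not taken from the linked issue:

```python
import tensorflow as tf

@tf.function
def squares(n):
    ta = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):
        # Reassigning `ta` makes it a loop variable, so each write is
        # carried into the next iteration and out of the loop.
        ta = ta.write(i, i * i)
    # Convert the TensorArray into a regular Tensor of shape (n,).
    return ta.stack()

result = squares(tf.constant(4))
print(result)  # values [0 1 4 9]
```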

Upvotes: 2
