Reputation: 809
In some cases, TensorFlow seems to be able to inspect the values of tensors at graph creation time, and in other cases it cannot.
>>> shape = [constant([2])[0], 3]
>>> reshape([1,2,3,4,5,6], shape)
<tf.Tensor 'Reshape_13:0' shape=(2, 3) dtype=int32>
>>> zeros(shape)
<tf.Tensor 'zeros_2:0' shape=(?, 3) dtype=float32>
In the example above, reshape() can see that the tensor passed in as shape has a value of 2, so the resulting output has a static shape of (2, 3). zeros() cannot, and its static shape is (?, 3). What is the reason for the difference?
My colleague posted "Determining tensor shapes at time of graph creation in TensorFlow", which stems from the same underlying issue. He was asking how best to work with TensorFlow to solve this kind of thing, whereas my question is why TensorFlow behaves this way. Is it a bug?
Upvotes: 4
Views: 167
Reputation: 2356
TL;DR: tf.reshape can infer the shape of its output but tf.zeros cannot. For both functions, shape accepts integers (static/definite) as well as Tensors (dynamic/indefinite).
Code makes this more concrete:
shape = [tf.constant([2])[0], tf.constant([3])[0]]
print(tf.reshape([1,2,3,4,5,6], shape))
# Tensor("Reshape:0", shape=(?, ?), dtype=int32)
print(tf.zeros(shape))
# Tensor("zeros:0", shape=(?, ?), dtype=float32)
and this:
shape = [tf.constant([5])[0], 3]
print(tf.reshape([1,2,3,4,5,6], shape))
# Tensor("Reshape:0", shape=(2, 3), dtype=int32)
# This will cause an InvalidArgumentError at run time!
When using a Tensor (like tf.constant([2])[0]) as shape to create another Tensor (like tf.zeros(shape)), the shape is always indefinite at graph creation time. However, tf.reshape() is different. It can infer the shape of the output from the shape of the input and the static part of the given shape.
In your code, 3 is a static integer and the shape of the input is given ([6]); the shape (2, 3) is in fact obtained by inference, not provided by you. The second snippet shows this: although I pass tf.constant([5]), the static shape does not change. (No error is raised at graph creation time, but one is raised at run time.)
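To make the inference rule concrete, here is a minimal pure-Python sketch of the behaviour described above. It is not TensorFlow's actual implementation; the function name infer_reshape_static_shape is hypothetical, None stands for a dimension supplied as a Tensor (shown as ? by TensorFlow), and num_input_elements=None models tf.zeros, which has no input to infer from.

```python
def infer_reshape_static_shape(num_input_elements, requested_shape):
    """Model of the static-shape rule described in this answer.

    requested_shape: ints for static dimensions, None for dimensions
    given as Tensors (unknown at graph creation time).
    num_input_elements: total element count of the input, or None
    when there is no input (the tf.zeros case).
    """
    unknown = [i for i, d in enumerate(requested_shape) if d is None]
    result = list(requested_shape)

    # Product of the dimensions that are known statically.
    known = 1
    for d in requested_shape:
        if d is not None:
            known *= d

    # One unknown dimension plus a known element count is enough to
    # solve for the missing value. The Tensor's real value is never read.
    if len(unknown) == 1 and num_input_elements is not None:
        if known > 0 and num_input_elements % known == 0:
            result[unknown[0]] = num_input_elements // known
    return tuple(result)


print(infer_reshape_static_shape(6, [None, 3]))     # (2, 3)
print(infer_reshape_static_shape(6, [None, None]))  # (None, None)
print(infer_reshape_static_shape(None, [None, 3]))  # (None, 3)
```

Because the value held by the Tensor never enters the calculation, this model reports (2, 3) whether the Tensor holds 2 or 5, which matches the tf.constant([5]) case failing only at run time.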
Upvotes: 2