user288609

Reputation: 13025

Regarding printing the shape of a tensor

I tested the following script:

import tensorflow as tf

a, b, c = 2, 3, 4
x = tf.Variable(tf.random_normal([a, b, c], mean=0.0, stddev=1.0, dtype=tf.float32))
s = tf.shape(x)
print(s)

init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
print(sess.run(s))

Running the code gives the following output:

Tensor("Shape:0", shape=(3,), dtype=int32)
[2 3 4]

It looks like only the second print gives a readable result. What does the first print actually do, and how should I interpret its output?

Upvotes: 0

Views: 648

Answers (1)

mrry

Reputation: 126154

The call to s = tf.shape(x) defines a symbolic (but very simple) TensorFlow computation that only executes when you call sess.run(s).

When you execute print(s), Python prints everything that TensorFlow knows about the tensor s without actually evaluating it. "Shape:0" is the tensor's name (output 0 of the op named "Shape"). Since s is the output of a tf.shape() op, TensorFlow knows that it has type tf.int32, and it can also infer that s is a vector of length 3, shown as shape=(3,), because x is statically known to be a 3-D tensor from the variable definition.
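A small sketch of the same idea: the attributes that print(s) summarizes can be read directly from s without running anything, while the actual values still require a session. It assumes TensorFlow 2.x running the graph-mode API through tf.compat.v1 (the question's code calls the TF 1.x API directly); the variable names follow the question.

```python
import tensorflow as tf

# Assumption: TF 2.x with the TF 1.x-style graph API via tf.compat.v1.
tf.compat.v1.disable_eager_execution()

a, b, c = 2, 3, 4
x = tf.Variable(tf.random.normal([a, b, c], mean=0.0, stddev=1.0, dtype=tf.float32))
s = tf.shape(x)

# All of this is known at graph-construction time; it is what print(s) shows.
print(s.name)   # name of the form "<op name>:<output index>"
print(s.dtype)  # element type of s (int32)
print(s.shape)  # static shape of s itself: a vector of length 3

# Only running the graph produces the concrete shape of x.
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    shape_value = sess.run(s)
    print(shape_value)
```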

Note that in many cases you can obtain more shape information without evaluating any tensor, by printing the static shape of a particular tensor with the Variable.get_shape() method (similar to its cousin Tensor.get_shape()):

# Print the statically known shape of `x`.
print(x.get_shape())
# ==> "(2, 3, 4)"
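To illustrate the "in many cases" caveat: the static shape is only as complete as what TensorFlow can infer when the graph is built. The sketch below uses a placeholder whose first dimension is left unspecified (an illustrative example not taken from the question, again assuming tf.compat.v1 on TensorFlow 2.x). get_shape() reports that dimension as None, and only tf.shape() evaluated in a session gives the actual size.

```python
import tensorflow as tf

# Assumption: TF 2.x with the TF 1.x-style graph API via tf.compat.v1.
tf.compat.v1.disable_eager_execution()

# Placeholder whose first dimension is unknown until data is fed.
p = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])

static_dims = p.get_shape().as_list()  # unknown dimensions appear as None
dynamic_shape = tf.shape(p)            # symbolic; evaluated only by sess.run()

with tf.compat.v1.Session() as sess:
    fed = [[0.0, 1.0, 2.0]] * 5          # 5 rows of 3 values each
    dynamic_value = sess.run(dynamic_shape, feed_dict={p: fed})

print(static_dims)
print(dynamic_value)
```

In short, get_shape() is free but may be incomplete, while tf.shape() is always exact but needs a run.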

Upvotes: 1
