Kong

Reputation: 2410

How to print tensor shape in tensorflow?

How do I print the shape of a tensor given a batch input? The code below does not work:

x_in = tf.identity(x_)
print_x_in = tf.Print(x_in, x_in.get_shape())

init = tf.global_variables_initializer()

# Start a new TF session
sess = tf.Session()

# Run the initializer
sess.run(init)

# feed in batch
sess.run(x_in, feed_dict={x_: x[1:10,:,:,:]})
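Note that the slice x[1:10,:,:,:] selects indices 1 through 9, so the fed batch holds 9 samples, not 10. A small NumPy sketch, assuming a hypothetical dataset of 20 images of size 28×28×1 (the question does not give the real dimensions):

```python
import numpy as np

# Hypothetical stand-in for the question's `x`: 20 grayscale 28x28 images.
x = np.zeros((20, 28, 28, 1), dtype=np.float32)

# Same slice as in the question: indices 1..9 (the stop index 10 is excluded).
batch = x[1:10, :, :, :]
print(batch.shape)  # (9, 28, 28, 1)

# To feed the first 10 samples instead, start the slice at 0.
first_ten = x[0:10]
print(first_ten.shape)  # (10, 28, 28, 1)
```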

Upvotes: 0

Views: 2394

Answers (2)

Zoe

Reputation: 1410

First, you never define x_. You need a placeholder, along the lines of:

x_ = tf.placeholder(tf.float32, shape=[None, shape[0], shape[1], shape[2]])

Then you can feed the values of x into x_.
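The None in the first position of the placeholder's shape leaves the batch dimension unspecified, so batches of any size can be fed while the remaining dimensions must match. The hypothetical helper below is a plain-Python sketch of that rule (it loosely mirrors the idea behind tf.TensorShape.is_compatible_with, but it is not TensorFlow code), with the image dimensions assumed to be 28×28×1:

```python
import numpy as np

def is_compatible(shape, spec):
    """Hypothetical helper: True if `shape` fits `spec`, where None in
    `spec` matches any size in that position."""
    if len(shape) != len(spec):
        return False
    return all(s is None or s == d for d, s in zip(shape, spec))

# Assumed placeholder spec: unknown batch size, then 28x28x1 images.
spec = (None, 28, 28, 1)

x = np.zeros((20, 28, 28, 1), dtype=np.float32)
print(is_compatible(x[1:10].shape, spec))   # True: batch of 9
print(is_compatible(x[:1].shape, spec))     # True: batch of 1
print(is_compatible((9, 32, 32, 1), spec))  # False: wrong height and width
```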

Once inside your session, evaluate the tensor:

x_out = sess.run(x_in, feed_dict={x_: x[1:10,:,:,:]})

which you can then print.

import numpy as np
print(np.shape(x_out))

Upvotes: 3

Avijit Dasgupta

Reputation: 2065

I do the following:

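x_in = tf.identity(x_)
with tf.Session() as sess:
    # x_ is a placeholder, so a batch must be fed before the shape can be evaluated
    print(sess.run(tf.shape(x_in), feed_dict={x_: x[1:10,:,:,:]}))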

If this is not what you are looking for, please give us more context.
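The difference between the question's x_in.get_shape() and tf.shape(x_in) is the difference between a static and a dynamic shape: get_shape() returns the shape known when the graph is built, in which the batch dimension of a [None, ...] placeholder is still unknown, whereas tf.shape(x_in) is evaluated at run time and reports the actual size of the fed batch. A plain-Python sketch of that distinction, using a hypothetical resolve_shape helper (not a TensorFlow function) and the same assumed 28×28×1 images:

```python
import numpy as np

def resolve_shape(static_shape, runtime_array):
    """Hypothetical helper: fill the unknown (None) dimensions of a static
    shape with the sizes of the array actually fed at run time."""
    runtime_shape = runtime_array.shape
    if len(static_shape) != len(runtime_shape):
        raise ValueError("rank mismatch")
    resolved = []
    for s, d in zip(static_shape, runtime_shape):
        # Known static dimensions must agree with the runtime array.
        if s is not None and s != d:
            raise ValueError(f"expected size {s}, got {d}")
        resolved.append(d)
    return tuple(resolved)

# Static shape as declared on the placeholder: batch size unknown.
static_shape = (None, 28, 28, 1)

x = np.zeros((20, 28, 28, 1), dtype=np.float32)
print(static_shape)                          # (None, 28, 28, 1)
print(resolve_shape(static_shape, x[1:10]))  # (9, 28, 28, 1)
```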

Upvotes: 0
