Reputation: 3253
I am trying to understand some TensorFlow code, so I am printing out the shapes of tensors. For the following code:
print(x.shape)
print(tf.shape(x))
I get this output:
(?, 32, 32, 3)
Tensor("input/Shape:0", shape=(4,), dtype=int32)
It does not make a lot of sense. Based on what I found online, tf.shape(x) can be used to get the batch size dynamically. But its output looks wrong: 4. I am not sure where this (4,) is coming from, or how to get the right value for my tensor.
Upvotes: 0
Views: 165
Reputation: 6166
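In fact, the two results are consistent. tf.shape(x) is itself a 1-D tensor that holds the dimensions of x, and x has 4 dimensions, (?, 32, 32, 3), so that tensor has shape (4,). The 4 is the number of dimensions of x, not the batch size.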
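To see where the (4,) comes from without running TensorFlow at all, here is a rough framework-free analogy. It uses plain nested Python lists in place of tensors, and shape_of is a hypothetical helper written for this illustration (it is not a TensorFlow function, and it assumes a non-empty rectangular nested list). The batch size of 2 is also just an assumed value.

```python
# Framework-free analogy: nested lists stand in for tensors.
# shape_of is a hypothetical helper, not part of TensorFlow.

def shape_of(nested):
    """Return the dimensions of a non-empty rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]  # descend one level into the first element
    return dims

# Stand-in for a batch of 2 images, each 32x32 with 3 channels.
x = [[[[0.0] * 3 for _ in range(32)] for _ in range(32)] for _ in range(2)]

dims = shape_of(x)            # the dimensions of x
dims_shape = shape_of(dims)   # the shape of the shape vector itself

print(dims)        # [2, 32, 32, 3]
print(dims_shape)  # [4]
```

The second print is the counterpart of the shape=(4,) in the question's output: it describes the vector of dimensions, which has one entry per dimension of x, rather than x itself.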
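x.shape (a property, not a method call) returns the static shape as a TensorShape, which prints like a tuple. You can read it without sess.run(), but any dimension that is unknown when the graph is built shows up as ?. You can use x.shape.as_list() to convert it into a list, where unknown dimensions appear as None.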
tf.shape(x) returns a tensor that holds the dynamic shape, and you need to evaluate it with sess.run() to get the actual numbers.
An example:
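import tensorflow as tf
import numpy as np

# TensorFlow 1.x API: the batch dimension is left unknown (None).
x = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
print(x.shape)       # static shape, known when the graph is built
print(tf.shape(x))   # symbolic tensor for the dynamic shape
dim = tf.shape(x)
dim0 = tf.shape(x)[0]
with tf.Session() as sess:
    # Feed a batch of 100 so the dynamic shape can be evaluated.
    dim, dim0 = sess.run([dim, dim0], feed_dict={x: np.random.uniform(size=(100, 32, 32, 3))})
print(dim)
print(dim0)
# Output: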
(?, 32, 32, 3)
Tensor("Shape:0", shape=(4,), dtype=int32)
[100 32 32 3]
100
Upvotes: 1