Reputation: 67
Does TensorFlow support determining the shape of a Tensor at runtime?
The problem is to build a constant tensor at runtime based on the input vector length_q. The number of columns of the target tensor is the sum of length_q. The code snippet is shown below; the length of length_q is fixed to 64.
T = tf.reduce_sum(length_q, 0)[0]
N = np.shape(length_q)[0]
wm = np.zeros((N, T), dtype=np.float32)
# Something irrelevant.
count = 0
for i in xrange(N):
    ones = np.ones(length_q[i])
    wm[i][count:count+length_q[i]] = ones
    count += length_q[i]
return tf.constant(wm)
Update
I want to create a dynamic Tensor according to the input length_q. length_q is an input vector (64*1). The shape of the new tensor depends on the sum of length_q, because the data in length_q changes in each batch. The current code snippet is as follows:
def some_matrix(length_q):
    T = tf.reduce_sum(length_q, 0)[0]
    N = np.shape(length_q)[0]
    wm = np.zeros((N, T), dtype=np.float32)
    count = 0
    return wm

def network_inference(length_q):
    wm = tf.constant(some_matrix(length_q))
    ...
The problem probably occurs because length_q is a placeholder, so its sum is not known when the graph is built. Is there a way to solve this?
Upvotes: 1
Views: 1539
Reputation: 126154
It sounds like the tf.fill() op is what you need. This op allows you to specify a shape as a tf.Tensor (i.e. a runtime value) along with a value:
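def some_matrix(length_q):
    T = tf.reduce_sum(length_q, 0)[0]  # total number of columns, computed at runtime
    N = tf.shape(length_q)[0]          # number of rows, also a runtime value
    wm = tf.fill([N, T], 0.0)          # N rows and T columns, as in the question
    return wm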
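Going one step further than a zero-filled matrix, the block pattern from the question's NumPy loop (row i holds ones in the columns belonging to segment i) can also be built from runtime values only. The following is a minimal sketch, not part of the original answer: the function name block_mask is hypothetical, and it assumes length_q holds non-negative integers. It uses only tf.cumsum, tf.range and element-wise comparisons, which are available in both TF 1.x and TF 2.x.

```python
import tensorflow as tf


def block_mask(length_q):
    """Return an [N, T] float32 matrix where row i has ones in its own segment.

    Hypothetical helper: T is the sum of length_q and is only known at runtime.
    """
    # Flatten the (N, 1) input to a vector of N integer lengths.
    lengths = tf.reshape(tf.cast(length_q, tf.int32), [-1])
    ends = tf.cumsum(lengths)        # exclusive end column of each segment
    starts = ends - lengths          # first column of each segment
    total = tf.reduce_sum(lengths)   # T, the number of columns
    cols = tf.range(total)           # column indices 0 .. T-1
    # Broadcast [1, T] column indices against [N, 1] segment bounds.
    inside = tf.logical_and(
        cols[tf.newaxis, :] >= starts[:, tf.newaxis],
        cols[tf.newaxis, :] < ends[:, tf.newaxis],
    )
    return tf.cast(inside, tf.float32)
```

With lengths 2, 1 and 3 this gives a 3 x 6 matrix whose rows cover columns 0-1, 2 and 3-5 respectively. In TF 1.x the result would still need to be evaluated with sess.run() and a feed_dict for the placeholder.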
Upvotes: 1
Reputation: 1726
It is not clear what you are calculating. If you need a tensor of shape N by T, you can generate ones like this:
T = tf.constant(20.0, tf.float32)  # stands in for the reduced sum; 20.0 is an example value
T = tf.cast(T, tf.int32)  # the number of columns must be an integer
N = 10  # a plain integer, assuming np.shape gives 10
# N = length_q.get_shape()[0]  # if it's a tensor; replace 'length_q' with your tensor's name
wm = tf.ones([N, T], dtype=tf.float32)  # N rows and T columns
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(wm)
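One caveat about reading N from the tensor itself: the static shape is only known when the graph is built with a fixed size, while tf.shape() returns the size as a runtime value. The sketch below illustrates the difference under TF 2.x, which is an assumption, since the code above targets TF 1.x; the function name ones_for is hypothetical.

```python
import tensorflow as tf


# The leading dimension is declared unknown, as with a placeholder of shape [None, 1].
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.int32)])
def ones_for(length_q):
    # length_q.shape[0] is None here, so it cannot be used to size the result.
    n = tf.shape(length_q)[0]      # dynamic number of rows
    t = tf.reduce_sum(length_q)    # dynamic number of columns
    return tf.ones(tf.stack([n, t]), dtype=tf.float32)
```

Calling ones_for(tf.constant([[2], [3]])) yields a 2 x 5 tensor of ones, and a different batch produces a different shape without rebuilding the function.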
Upvotes: 0