anon

Reputation: 417

How to get the value from a tensor

Here's my setup:

import numpy as np
import tensorflow as tf

indices = tf.placeholder(tf.int32, shape=[2])
weights = tf.Variable(tf.random_normal([100000, 3], stddev=0.35))


def objective(indices, weights):
    """Return the sum of the element-wise product of two rows of ``weights``.

    This version reproduces the problem being asked about: ``idx1`` and
    ``idx2`` are symbolic tensors (slices of the ``indices`` placeholder),
    not Python ints, so they cannot be used to index the NumPy array ``mask``.
    """
    idx1 = indices[0]
    idx2 = indices[1]  # extract two indices
    mask = np.zeros(weights.shape.as_list()[0])  # builds a mask for some tensor "weights"
    # The next two assignments fail as soon as objective() is called: a NumPy
    # array cannot be indexed with a tf.Tensor while the graph is being built.
    mask[idx1] = 1  # don't ask why I want to do this. I just do.
    mask[idx2] = 1
    obj = tf.reduce_sum(tf.multiply(weights[idx1], weights[idx2]))
    return obj


optimizer = tf.train.GradientDescentOptimizer(0.01)

obj = objective(indices, weights)
trainer = optimizer.minimize(obj)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # feed_dict maps the placeholder to its value, so it needs ':' (not '=').
    sess.run([trainer, obj], feed_dict={indices: [100, 1000]})

The point is that I have a tensor, and I take a slice of it that corresponds to an index in my mask. That index is the output of a tf.strided_slice op. I want to index my mask with idx1 and idx2, since both should evaluate to ints.

But idx1 and idx2 are not ints; they are symbolic tensors. So the call obj = objective(indices, weights) fails as soon as it tries to index the NumPy array mask with them.

How can I get the code to work?

Upvotes: 2

Views: 497

Answers (1)

pfm

Reputation: 6328

Because idx1 and idx2 only receive concrete values when the graph runs, build the mask inside the graph instead of in NumPy. A combination of tf.SparseTensor and tf.sparse_tensor_to_dense does exactly that:

import numpy as np
import tensorflow as tf

indices = tf.placeholder(tf.int64, shape=[2])  # SparseTensor indices must be int64
weights = tf.Variable(tf.random_normal([5, 3], stddev=0.35))


def objective(indices, weights):
    """Build the objective plus a mask that is 1 at both fed row indices.

    The mask is built with graph ops because ``indices`` only has concrete
    values when the session runs.

    Returns:
        (obj, mask): ``obj`` is the sum of the element-wise product of rows
        ``indices[0]`` and ``indices[1]`` of ``weights``; ``mask`` has one
        entry per row of ``weights``, set to 1 at the two fed indices.
    """
    idx1 = indices[0]
    idx2 = indices[1]  # extract two indices
    mask = np.zeros(weights.shape.as_list()[0])  # base mask, one entry per row of weights
    # Each fed index becomes a one-element coordinate holding the value 1.
    mask_ones = tf.SparseTensor(tf.reshape(indices, [-1, 1]), [1, 1], mask.shape)
    # sparse_tensor_to_dense rejects indices that are not in sorted order, so
    # reorder first; this lets callers feed [4, 0] as well as [0, 4].
    # NOTE: feeding the same index twice (e.g. [3, 3]) is still rejected.
    mask_ones = tf.sparse_reorder(mask_ones)
    # The values [1, 1] are Python ints, so the mask comes out as an integer
    # tensor; wrap it in tf.cast(..., weights.dtype) if a float mask is needed.
    mask = mask + tf.sparse_tensor_to_dense(mask_ones)  # set the mask
    obj = tf.reduce_sum(tf.multiply(weights[idx1], weights[idx2]))
    return obj, mask


obj, mask = objective(indices, weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([weights, obj, mask], feed_dict={indices: [0, 4]}))

[array([[...]], dtype=float32), 0.0068909675, array([1, 0, 0, 0, 1], dtype=int32)]

Upvotes: 1

Related Questions