spreisel

Reputation: 391

How to index a list with a TensorFlow tensor?

Assume a list of non-concatenable objects that needs to be accessed via a look-up table. The list index is therefore a tensor object, but indexing a Python list with a tensor is not possible.

 import numpy as np
 import tensorflow as tf

 tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4]))
 index = tf.constant(2)
 list = [0,1,2,3,4]

 target = list[tf_look_up[index]]

This raises the following error:

 TypeError: list indices must be integers or slices, not Tensor

Is there a way or workaround to index lists with tensors?

Upvotes: 18

Views: 17171

Answers (2)

soloice

Reputation: 1040

tf.gather is designed for this purpose.

Simply run tf.gather(list, tf_look_up[index]) and you'll get what you want.
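For illustration, here is a minimal sketch of that call, assuming TensorFlow 2.x with eager execution enabled. The names `items` and `objects` are hypothetical and not from the question. Note that tf.gather first converts the list to a tensor, so it only helps when the items can be packed into one. For items that cannot, the second half shows a possible eager-mode workaround that turns the scalar tensor into a plain Python int.

```python
import numpy as np
import tensorflow as tf

tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4]))
index = tf.constant(2)
items = [0, 1, 2, 3, 4]  # hypothetical name; avoids shadowing the built-in `list`

# tf.gather converts `items` to a tensor, then picks the element at the
# position stored in tf_look_up[index] (here position 1).
target = tf.gather(items, tf_look_up[index])

# Eager mode only (assumption): convert the scalar tensor to a Python int.
# This also works for lists whose items cannot be packed into a tensor.
objects = ["a", {"b": 1}, None, (3,), 4.0]  # hypothetical mixed-type list
py_target = objects[int(tf_look_up[index].numpy())]
```

The second approach does not carry over to graph mode, where the tensor has no concrete value at the time the Python list is indexed.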

Upvotes: 15

jbird

Reputation: 506

TensorFlow actually has support for a HashTable. See the documentation for more details.

Here, what you could do is the following:

table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, list), -1)

Then get the desired value by running:

target = table.lookup(index)

Note that -1 is the default value if the key is not found. You may have to add key_dtype and value_dtype to the constructor depending on the configuration of your tensors.
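The tf.contrib module was removed in TensorFlow 2.x. As a rough sketch of the same idea there, assuming eager execution, tf.lookup.StaticHashTable can stand in for the table above. The name `items` is hypothetical, and the explicit int64 dtypes are an assumption made so that the keys and the queried index match.

```python
import numpy as np
import tensorflow as tf

# Keys and the query index share one explicit dtype (int64); a lookup whose
# key dtype differs from the table's key dtype is rejected.
tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4], dtype=np.int64))
items = tf.constant([0, 1, 2, 3, 4], dtype=tf.int64)  # hypothetical name for `list`
index = tf.constant(2, dtype=tf.int64)

# As in the answer above: the look-up values are the keys and the list
# entries are the values, so key 2 maps to items[1].
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(tf_look_up, items),
    default_value=-1,
)

target = table.lookup(index)  # value stored under key 2
missing = table.lookup(tf.constant(7, dtype=tf.int64))  # key absent, default -1
```

This table maps each look-up value to the list entry at the same position. For the data in the question that gives the same result as list[tf_look_up[index]], because the permutation [3, 2, 1, 0, 4] is its own inverse; for other look-up tables the two can differ.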

Upvotes: 2
