user1371314

Reputation: 832

Tensorflow indexing into python list during tf.while_loop

I have an annoying problem and I don't know how to solve it.

I am reading batches of data from a CSV with a dataset reader and want to gather certain columns. The reader returns a tuple of tensors and, depending on which reader I use, columns are indexed by either integer or string.

I can easily write a for loop in Python and slice the columns I want, but I want to do this in a tf.while_loop to take advantage of parallel execution.

This is where my issue lies: the loop counter in the while loop is a tensor, and I cannot use it to index into my dataset. If I try to evaluate it, I get an error about the session not being the same.

How can I use a while loop (or a map function) and have the function index into a Python list/dict without evaluating or running the counter tensor?

Simple example:

        some_data = [1,2,3,4,5]

        x = tf.constant(0)
        y = len(some_data)
        c = lambda x: tf.less(x, y)
        b = lambda x: some_data[x]  # <--- You cannot index like this!

        tf.while_loop(c, b, [x])

Upvotes: 2

Views: 438

Answers (1)

Mohan Radhakrishnan

Reputation: 3197

Does this fit your requirement somewhat? It does nothing apart from printing the value.

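import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_shape

some_data = [11,222,33,4,5,6,7,8]

# Runs as ordinary Python when the session executes; v arrives as a NumPy value.
def func( v ):
    print (some_data[v])
    # Cast so the result matches the tf.int32 declared in tf.py_func below.
    return np.int32(some_data[v])

with tf.Session() as sess:
    r = tf.while_loop(
        lambda i, v: i < 4,
        # tf.py_func defers the list lookup until i has a concrete value.
        lambda i, v: [i + 1, tf.py_func(func, [i], [tf.int32])[0]],
        [tf.constant(0), tf.constant(2, tf.int32)],
        [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])

    r[1].eval()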

It prints

11 4 222 33

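The order changes every time, presumably because tf.while_loop may run iterations in parallel (parallel_iterations defaults to 10) and each tf.py_func call depends only on i. I guess tf.control_dependencies, or passing parallel_iterations=1, may be useful to control that.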
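If the list is already known when the graph is built, another option that avoids tf.py_func might be to turn it into a tensor first, since a tensor (unlike a Python list) can be indexed by the loop counter. The following is only a minimal sketch under assumptions that are not in the question: it uses tensorflow.compat.v1 (available from TensorFlow 1.14 onwards) to keep the Session style used above, and the names data, body, total and picked are made up for illustration. For a tuple of same-typed column tensors, tf.stack would presumably play the role of tf.constant here.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, to match the Session-based code above

some_data = [11, 222, 33, 4, 5, 6, 7, 8]

# Assumption: the list is known at graph-construction time,
# so it can be embedded in the graph as a constant tensor.
data = tf.constant(some_data, dtype=tf.int32)

def body(i, total):
    # data[i] is a graph op, so indexing with the tensor i is allowed here.
    return i + 1, total + data[i]

# Sum the first four entries inside the loop.
_, total = tf.while_loop(
    lambda i, total: i < 4,
    body,
    [tf.constant(0), tf.constant(0)])

# Picking several entries at once does not need a loop at all.
picked = tf.gather(data, [0, 2, 3])

with tf.Session() as sess:
    total_value, picked_value = sess.run([total, picked])
    print(total_value, picked_value)
```

Because the lookup stays inside the graph, nothing has to be evaluated while the loop is being built, and tf.gather covers the common case of selecting a fixed set of columns.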

Upvotes: 1
