Reputation: 832
I have this annoying problem and I don't know how to solve it.
I am reading batches of data from a CSV using a dataset reader and want to gather certain columns. The reader returns a tuple of tensors and, depending on which reader I use, columns are indexed either by integer or by string.
I can easily write a for loop in Python and slice the columns I want; however, I want to do this in a tf.while_loop to take advantage of parallel execution.
This is where my issue lies: the iterator in the while loop is a tensor, and I cannot use it to index into my dataset. If I try to evaluate it, I get an error about the session not being the same, etc.
How can I use a while loop (or a map function) and have the function index into a Python list/dict without evaluating or running the iterator tensor?
Simple example:
some_data = [1,2,3,4,5]
x = tf.constant(0)
y = len(some_data)
c = lambda x: tf.less(x, y)
b = lambda x: some_data[x]  # <--- You cannot index like this!
tf.while_loop(c, b, [x])
Upvotes: 2
Views: 438
Reputation: 3197
Does this fit your requirement somewhat? It does nothing apart from printing the value.
import tensorflow as tf
from tensorflow.python.framework import tensor_shape

some_data = [11, 222, 33, 4, 5, 6, 7, 8]

def func(v):
    # Runs as ordinary Python inside tf.py_func, so v arrives as a
    # concrete NumPy value and can index the list directly.
    print(some_data[v])
    return some_data[v]

with tf.Session() as sess:
    r = tf.while_loop(
        lambda i, v: i < 4,
        lambda i, v: [i + 1, tf.py_func(func, [i], [tf.int32])[0]],
        [tf.constant(0), tf.constant(2, tf.int32)],
        [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
    r[1].eval()
It prints
11 4 222 33
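The order changes every time, presumably because the py_func calls in different iterations do not depend on each other and can run in parallel, but I guess tf.control_dependencies may be useful to control that.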
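If the data can live in the graph as a tensor rather than a Python list, another option is to index it with tf.gather, which accepts a tensor index and so avoids both py_func and evaluating the loop counter. The following is only a minimal sketch under some assumptions: it uses the tensorflow.compat.v1 API (available in TF 1.14+ and 2.x), the values fit in int32, and the loop body simply sums the first four entries as a stand-in for whatever you actually want to do with each value.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with the v1 compat API

tf.disable_eager_execution()  # keep the graph/session style used above

some_data = [11, 222, 33, 4, 5, 6, 7, 8]
# Convert the Python list to a tensor once, so it can be indexed inside the graph.
data = tf.constant(some_data, dtype=tf.int32)

def cond(i, acc):
    # Loop over the first four elements only, as in the example above.
    return i < 4

def body(i, acc):
    # tf.gather takes a tensor index, so the loop counter is never evaluated in Python.
    value = tf.gather(data, i)
    return [i + 1, acc + value]

# Loop variables: the counter i and a running sum (illustrative only).
_, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0, tf.int32)])

with tf.Session() as sess:
    result = sess.run(total)
    print(result)
```

For a tuple of column tensors, the same idea should apply after combining them with tf.stack, assuming the columns share a dtype and shape. String keys would first need to be mapped to integer positions in ordinary Python when the graph is built.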
Upvotes: 1