Reputation: 427
I wrote this code to keep only the values of a Dataset
that are <= 6.
import tensorflow as tf
import tensorflow.contrib.data as ds
def make_graph():
    """Build the graph that keeps the values <= 6 from each batch of ``range(10)``.

    The dataset is batched (batch size 3) first and filtered afterwards, so a
    batch in which no value passes the test evaluates to an empty tensor.

    Returns:
        A ``(found, inits)`` tuple: ``found`` is the tensor holding the
        surviving values of the current batch, and ``inits`` is the list of
        iterator initializer ops to run before evaluating ``found``.
    """
    threshold = tf.constant([6], dtype=tf.int64)

    iterator = ds.Dataset.range(10).batch(3).make_initializable_iterator()
    batch = iterator.get_next()
    init_ops = [iterator.initializer]

    # Boolean mask -> coordinates of the True entries -> flat index vector.
    keep_mask = tf.less_equal(batch, threshold, name="less_than_filter")
    keep_indices = tf.reshape(tf.where(keep_mask), [-1])

    # Pick the batch elements at the surviving positions.
    kept = tf.gather(params=batch, indices=keep_indices)
    return kept, init_ops
def run_graph(final_tensor, initializers, rounds):
    """Evaluate ``final_tensor`` until its input is exhausted, ``rounds`` times.

    Each round runs every op in ``initializers`` and then evaluates
    ``final_tensor`` repeatedly, printing each result, until TensorFlow raises
    ``OutOfRangeError``. The session graph is also written as a summary event
    file into the current directory.

    Args:
        final_tensor: Tensor (or op) to evaluate on every step.
        initializers: Iterable of ops to run at the start of each round.
        rounds: Number of full passes to perform; nothing is evaluated when
            it is <= 0.
    """
    with tf.Session() as sess:
        init_ops = (
            tf.local_variables_initializer(),
            tf.global_variables_initializer(),
        )
        sess.run(init_ops)
        summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
        try:
            while rounds > 0:
                for init in initializers:
                    sess.run(init)
                try:
                    while True:
                        final_result = sess.run(final_tensor)
                        print("Got result: {r}".format(r=final_result))
                except tf.errors.OutOfRangeError:
                    # Raised when the iterator is exhausted; marks the end of
                    # one round rather than a failure.
                    print("Got out of range error")
                rounds -= 1
        finally:
            # close() flushes pending events and releases the event file,
            # even if a step above raised.
            summary_writer.close()
def run():
    """Build the filtering graph and execute a single pass over it."""
    graph_output, graph_inits = make_graph()
    run_graph(final_tensor=graph_output, initializers=graph_inits, rounds=1)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()
However, the results are as follows:
Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error
Is there a way to filter out this empty Tensor? I tried to brainstorm ways to do this, maybe with a tf.while_loop,
but I'm not sure whether I'm missing something or whether such an operation (i.e. an OpKernel "dropping" an input by not producing output based on its value) is not possible in TensorFlow.
Upvotes: 4
Views: 2693
Reputation: 19123
Keeping only values <= 6 BEFORE batching:
# Drop every element above 6 one by one, then group what is left into
# batches of 3, so no batch can end up empty after filtering.
dataset = ds.Dataset.range(10).filter(lambda v: v <= 6).batch(3)
batched_iter = dataset.make_initializable_iterator()
This will generate batches containing only the data you want. Note that it's generally better to filter out the unwanted data before building the batches. This way, empty tensors will not be generated by the iterator.
Upvotes: 3