Sam

Reputation: 427

Filtering "empty" values from a TensorFlow Dataset

I wrote this code to keep only the values of a Dataset that are <= 6.

import tensorflow as tf
import tensorflow.contrib.data as ds

def make_graph(threshold=6):
    """Build a graph yielding batches of ``range(10)`` filtered to values <= threshold.

    The source dataset ``range(10)`` is split into batches of 3. Each batch
    returned by the iterator is masked so that only elements ``<= threshold``
    remain. Because the filtering happens *after* batching, a batch whose
    elements all exceed the threshold comes out as an empty tensor
    (e.g. the final batch ``[9]`` becomes ``[]``).

    Args:
        threshold: Inclusive upper bound on the values kept. Defaults to 6,
            the value that was previously hard-coded.

    Returns:
        A tuple ``(found, inits)``: ``found`` is the filtered batch tensor, and
        ``inits`` is a list of iterator initializer ops that have to be run
        before ``found`` can be evaluated.
    """
    inits = []
    filter_value = tf.constant(threshold, dtype=tf.int64)
    source = ds.Dataset.range(10)
    batched = source.batch(3)
    batched_iter = batched.make_initializable_iterator()
    batched_next = batched_iter.get_next()
    inits.append(batched_iter.initializer)
    predicate = tf.less_equal(batched_next, filter_value, name="less_than_filter")
    # boolean_mask keeps, in order, the elements where predicate is True; this
    # replaces the manual where -> reshape -> gather sequence.
    found = tf.boolean_mask(batched_next, predicate)
    return found, inits

def run_graph(final_tensor, initializers, rounds):
    """Evaluate ``final_tensor`` until its dataset is exhausted, ``rounds`` times.

    Every result is printed. The graph is written to the current directory
    through a ``tf.summary.FileWriter``, which is always closed on exit.

    Args:
        final_tensor: Tensor evaluated on each step (e.g. a filtered batch).
        initializers: Iterator initializer ops. They are re-run at the start
            of every round so that the dataset is replayed from the beginning.
        rounds: Number of full passes over the dataset.
    """
    with tf.Session() as sess:
        init_ops = (tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_ops)
        summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
        try:
            for _ in range(rounds):
                for init_op in initializers:
                    sess.run(init_op)
                try:
                    while True:
                        final_result = sess.run(final_tensor)
                        print("Got result: {r}".format(r=final_result))
                except tf.errors.OutOfRangeError:
                    # The iterator signals "no more batches" with this error.
                    print("Got out of range error")
        finally:
            # close() flushes pending events and releases the event file, even
            # if a round raised part-way through.
            summary_writer.close()

def run():
    """Build the filtering graph and execute it for exactly one round."""
    graph_output, graph_inits = make_graph()
    run_graph(final_tensor=graph_output, initializers=graph_inits, rounds=1)


if __name__ == "__main__":
    run()

However, the results are as follows:

Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error

Is there a way to filter out this empty tensor? I tried to brainstorm ways to do this, maybe with a `tf.while_loop`, but I'm not sure whether I'm missing something or whether such an operation (i.e. an OpKernel "dropping" an input by not producing output based on its value) is simply not possible in TensorFlow.

Upvotes: 4

Views: 2693

Answers (1)

GPhilo

Reputation: 19123

Keep only the values <= 6 *before* batching:

# Filter first, then batch, so no batch can ever end up empty.
dataset = (
    ds.Dataset.range(10)
    .filter(lambda value: value <= 6)
    .batch(3)
)
batched_iter = dataset.make_initializable_iterator()

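This generates batches containing only the data you want. It is generally better to filter out unwanted data before building the batches, because then the iterator never produces empty tensors. Note that filtering first regroups the surviving elements, so the batch boundaries can differ from those you get by batching first. Here the batches become [0 1 2], [3 4 5], [6].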

Upvotes: 3

Related Questions