Rishav Bhardwaj

Iterating Batches through a TensorFlow Dataset Generator

Let's say I have:

import numpy as np
sequence = np.array([[1], [2], [3], [4], [5]])

and a generator defined as:

def generator():
    for el in sequence:
        yield el
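A quick plain-Python check of what this generator yields (it assumes NumPy is installed but does not need TensorFlow). Iterating over a 2-D array yields its rows, so each element is a NumPy array of shape (1,), which is what output_shapes=tf.TensorShape([1]) describes:

```python
import numpy as np

# Same data and generator as in the question.
sequence = np.array([[1], [2], [3], [4], [5]])

def generator():
    for el in sequence:
        yield el

# Each yielded element is one row of `sequence`.
shapes = [el.shape for el in generator()]
values = [int(el[0]) for el in generator()]
print(shapes)  # → [(1,), (1,), (1,), (1,), (1,)]
print(values)  # → [1, 2, 3, 4, 5]
```

Note that a plain `for` loop over the generator ends on its own when the generator raises StopIteration. The TensorFlow 1.x session-based iterator does not plug into a `for` loop that way.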

Now I want to use TensorFlow's tf.data.Dataset.from_generator() to read the data from the generator:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)

# from_generator is a static method, so call it on the class itself
dataset = tf.data.Dataset.from_generator(generator,
                                         output_types=tf.int64,
                                         output_shapes=tf.TensorShape([1]))
iterator = dataset.make_initializable_iterator()
el = iterator.get_next()

To retrieve the elements, I currently call sess.run(el) once per element:

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))

Is there a way to get 'el' in a loop instead of writing out sess.run(el) every time?

Answers (1)

gorjan

Run sess.run(el) in an endless loop and catch tf.errors.OutOfRangeError, which the iterator raises once the dataset is exhausted:

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        print("Iterating finished")
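To see this control flow without TensorFlow, here is a plain-Python analogue. It is only a sketch: OutOfRangeError below is a hypothetical stand-in defined locally, not tf.errors.OutOfRangeError. make_get_next and drain are illustrative helpers. The callable from make_get_next mimics repeated sess.run(el) calls, and drain loops until the end-of-data exception stops it.

```python
class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError (not the real class)."""


def make_get_next(items):
    """Return a callable that mimics sess.run(el): each call returns the next
    item and raises OutOfRangeError once the items are exhausted."""
    it = iter(items)

    def get_next():
        try:
            return next(it)
        except StopIteration:
            raise OutOfRangeError("End of sequence") from None

    return get_next


def drain(get_next):
    """Call get_next() repeatedly until it signals exhaustion; return all results."""
    results = []
    try:
        while True:
            results.append(get_next())
    except OutOfRangeError:
        pass  # exhaustion is the expected way to leave the loop
    return results


sequence = [[1], [2], [3], [4], [5]]
print(drain(make_get_next(sequence)))  # → [[1], [2], [3], [4], [5]]
```

Placing the try/except outside the while loop makes exhaustion the normal exit path. Any other exception still propagates.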
