Reputation: 45
If I use multiple elements from a tf.data.Dataset to build the graph, and then evaluate the graph later, the order in which the elements come out of the Dataset seems to be undefined. As an example, the following code snippet
import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()

print 'build graph and then eval'
keep = []
for i in range(5):
    keep.append(iterator.get_next())
with tf.Session() as sess:
    keep_eval = sess.run(keep)
    print keep_eval

print 'eval each element'
with tf.Session() as sess:
    for i in range(5):
        print sess.run(iterator.get_next()),
will result in output like:
build graph and then eval
[3 0 1 4 2]
eval each element
0 1 2 3 4
Also, each run yields a different result for "build graph and then eval". I would expect "build graph and then eval" to be ordered as well, like "eval each element". Can anyone explain why this happens?
Upvotes: 1
Views: 1737
Reputation: 681
From the TensorFlow FAQ:
The individual ops have parallel implementations, using multiple cores in a CPU, or multiple threads in a GPU.
So your "build graph and then eval" call runs the ops for all the elements in the list in parallel, which is why the numbers come out in a random order, while the for loop forces one call to run after another, so it is serial. You can verify this by timing both: the first one should be fast, while the for loop will be slower.
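For instance, here is a rough way to time the two approaches (a sketch, assuming TF 1.x; the exact numbers will depend on your machine):

import time
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# One `sess.run()` call over five ops: the ops may execute in parallel.
iterator = dataset.make_one_shot_iterator()
batch = [iterator.get_next() for _ in range(5)]
with tf.Session() as sess:
    start = time.time()
    sess.run(batch)
    print('single run: %.6f s' % (time.time() - start))

# Five `sess.run()` calls, one element each: strictly serial.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    start = time.time()
    for _ in range(5):
        sess.run(next_element)
    print('five runs:  %.6f s' % (time.time() - start))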
Upvotes: 1
Reputation: 126184
The order of a tf.data.Dataset is defined and deterministic (unless you add a non-deterministic Dataset.shuffle()).
However, your two loops build different graphs, which accounts for the difference:
The "build graph and then eval" part creates a list of five iterator.get_next()
operations and runs the five operations in parallel. Because these operations run in parallel, they may produce results in different order.
The "eval each element" part also creates five iterator.get_next()
operations, but it runs them sequentially, so you always get the results in the expected order.
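If you want to convince yourself that the parallelism is what scrambles the order, you can chain the five operations with tf.control_dependencies() so that they cannot run concurrently. This is a minimal sketch (TF 1.x), not something you would normally do in practice:

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()

# Make each `get_next()` op wait for the previous one, so a single
# `sess.run()` call must execute them one after another.
keep = [iterator.get_next()]
for _ in range(4):
    with tf.control_dependencies([keep[-1]]):
        keep.append(iterator.get_next())

with tf.Session() as sess:
    print(sess.run(keep))  # should now always print [0, 1, 2, 3, 4]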
Note that we do not recommend calling iterator.get_next() in a loop, because it creates a new operation on each call, which gets added to the graph and consumes memory. Instead, when you loop over a Dataset, try to use the following pattern:
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()

# Call `iterator.get_next()` once and use the result in each iteration.
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        print sess.run(next_element)
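A related detail: once the iterator is exhausted, sess.run(next_element) raises tf.errors.OutOfRangeError, so an open-ended loop over a Dataset is usually written with a try/except (a sketch of the same TF 1.x pattern):

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            # The iterator has no more elements.
            break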
Upvotes: 3