Reputation: 318
I'm starting to use the Dataset API to replace the feed_dict system.
However, after creating your Dataset pipeline, how can you feed the Dataset's data to the model without using feed_dict?
First, I created a one-shot iterator. But in that case, you still need to use feed_dict to provide the data coming from your iterator to the model.
Secondly, I tried to create my dataset directly from a tf.placeholder and then use an initializable_iterator. But here again, I don't understand how to get rid of feed_dict. In addition, I don't understand the purpose of this kind of dataset based on placeholders.
My very basic model:
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()
My data:
np_data = np.random.sample((100,2))
Method 1:
dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
    sess.run(init_dense)
    for i in range(100):
        value = sess.run(next_value)  # one slice of shape (2,)
        # Cannot get rid of feed_dict; add a batch dim to match x's shape [None, 2]
        result = sess.run(dense, feed_dict={x: value[None, :]})
Method 2:
dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
    sess.run(init_dense)
    sess.run(iterator.initializer, feed_dict={x: np_data})
    for i in range(100):
        value = sess.run(next_value)  # one slice of shape (2,)
        # Cannot get rid of feed_dict; add a batch dim to match x's shape [None, 2]
        result = sess.run(dense, feed_dict={x: value[None, :]})
https://www.tensorflow.org/guide/performance/overview#input_pipeline
So, how can I "Avoid using feed_dict for all but trivial examples"? I think I haven't understood the concept of the Dataset API.
Upvotes: 2
Views: 456
Reputation: 14485
Yes, we need not use feed_dict if we're using the Dataset API. Instead, we can just apply the dense layer to next_value each time. Something like this:
def model(x):
    dense = tf.layers.dense(x, 1)
    return dense

result_for_this_iteration = model(next_value)
So your full toy example might look something like this:
import numpy as np
import tensorflow as tf

def model(x):
    dense = tf.layers.dense(x, 10)
    return dense

# Slicing along the first axis yields 100 elements of shape (2, 2)
dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

# The model consumes the iterator's output directly; no feed_dict needed
result_for_this_iteration = model(next_value)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    while True:
        try:
            result = sess.run(result_for_this_iteration)
            print(result)
        except tf.errors.OutOfRangeError:
            print("no more data")
            break
Of course, additional configuration options abound. We can repeat() so that we don't reach the end of the data but loop over it, batch(n) into batches of size n, and map(pre_process) to apply a pre_process function to each element, etc. A sketch combining these follows below.
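For example, here is a minimal sketch chaining those transformations together (the batch size of 10, the two repeats, and the pre_process function are illustrative choices, not from the post above):

import numpy as np
import tensorflow as tf

def pre_process(element):
    # Illustrative transform: scale each element
    return element * 2.0

dataset = (tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2)))
           .map(pre_process)  # apply pre_process to every element
           .batch(10)         # group elements into batches of size 10
           .repeat(2))        # loop over the whole dataset twice

iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_batch).shape)  # (10, 2) for each batch
        except tf.errors.OutOfRangeError:
            break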
Upvotes: 2