wakobu

Reputation: 318

Does Tensorflow Dataset API totally get rid of feed_dict argument?

I'm starting to use the Dataset API to replace the feed_dict system.

However, once you have created your Dataset pipeline, how do you feed the Dataset's data to the model without using feed_dict?

First, I created a one-shot iterator. But in that case, you still need feed_dict to pass the data coming from your iterator to the model.

Second, I tried to create my dataset directly from a tf.placeholder and then use an initializable_iterator. But here again, I don't understand how to get rid of feed_dict. Besides, I don't understand the purpose of this kind of placeholder-based dataset.

My very basic model:

x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_glob = tf.global_variables_initializer()

My data:

np_data = np.random.sample((100,2))

Method 1:

dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict={x: value})

Method 2:

dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)
  sess.run(iterator.initializer, feed_dict={x: np_data})

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict={x: value})

https://www.tensorflow.org/guide/performance/overview#input_pipeline

So, how can I "Avoid using feed_dict for all but trivial examples"? I think I haven't understood the concept behind the Dataset API.

Upvotes: 2

Views: 456

Answers (1)

Stewart_R

Reputation: 14485

Yes, with the Dataset API there is no need for feed_dict at all.

Instead of feeding a placeholder, we can just apply the dense layer to next_value directly each time.

Something like this:

def model(x):
  dense = tf.layers.dense(x, 1)
  return dense

result_for_this_iteration = model(next_value)

So your full toy example might look something like this:

import numpy as np
import tensorflow as tf

def model(x):
  dense = tf.layers.dense(x, 10)
  return dense

dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

result_for_this_iteration = model(next_value)


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  while True:
    try:
      result = sess.run(result_for_this_iteration)
      print(result)
    except tf.errors.OutOfRangeError:
      print("no more data")
      break

Of course, additional configuration options abound. We can repeat() so that we don't reach the end of the data but instead loop over it. We can batch(n) to group elements into batches of size n. We can map(pre_process) to apply a pre_process function to each element, etc.
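To illustrate, here is a minimal sketch chaining those three transformations together (written against the tf.compat.v1 API so it also runs under TensorFlow 2.x; `pre_process` here is just a hypothetical scaling function, not anything from your code):

```python
import numpy as np
import tensorflow.compat.v1 as tf  # v1-style graph API; also available on TF 2.x
tf.disable_eager_execution()

def pre_process(element):
  # Hypothetical preprocessing step: scale every feature by 2
  return element * 2.0

np_data = np.random.sample((100, 2)).astype(np.float32)

dataset = (tf.data.Dataset.from_tensor_slices(np_data)
           .map(pre_process)  # apply pre_process to each element
           .batch(10)         # group elements into batches of size 10
           .repeat(2))        # iterate over the whole dataset twice

iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

batches = 0
with tf.Session() as sess:
  try:
    while True:
      value = sess.run(next_batch)
      batches += 1
  except tf.errors.OutOfRangeError:
    pass

print(batches)  # 100 elements / batches of 10, repeated twice -> 20 batches
```

Note that the order matters: `batch(10).repeat(2)` yields 20 full batches, whereas `repeat(2).batch(10)` would batch across the epoch boundary.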

Upvotes: 2
