erobertc

Reputation: 644

Replacing tf.placeholder and feed_dict with tf.data API

I have an existing TensorFlow model which used a tf.placeholder for the model input and the feed_dict parameter of tf.Session().run to feed in data. Previously the entire dataset was read into memory and passed in this way.

I want to use a much larger dataset and take advantage of the performance improvements of the tf.data API. I've defined a tf.data.TextLineDataset and one-shot iterator from it, but I'm having a hard time figuring out how to get the data into the model to train it.

At first I tried to just define the feed_dict as a dictionary from the placeholder to iterator.get_next(), but that gave me an error saying the value of a feed cannot be a tf.Tensor object. More digging led me to understand that this is because the object returned by iterator.get_next() is already part of the graph, unlike what you would feed into feed_dict -- and that I shouldn't be trying to use feed_dict at all anyway for performance reasons.

So now I've gotten rid of the input tf.placeholder and replaced it with a parameter to the constructor of the class that defines my model; when constructing the model in my training code, I pass the output of iterator.get_next() to that parameter. This already seems a bit clunky because it breaks separation between the definition of the model and the datasets/training procedure. And I'm now getting an error saying that the Tensor representing (I believe) my model's input must be from the same graph as the Tensor from iterator.get_next().

Am I on the right track with this approach and just doing something wrong with how I set up the graph and the session, or something like that? (The datasets and model are both initialized outside of a session, and the error occurs before I attempt to create one.)

Or am I totally off base with this and need to do something different like use the Estimator API and define everything in an input function?

Here is some code demonstrating a minimal example:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])

Upvotes: 20

Views: 6028

Answers (2)

erobertc

Reputation: 644

The line tf.reset_default_graph() in the model's constructor (from the original code I was given) was what caused the error. Removing it fixed the problem.
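
For reference, here is the minimal example from the question with that line removed; everything else is unchanged:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        self.x_in = x_in  # tensor coming from iterator.get_next()
        self.output_size = 3

        # No tf.reset_default_graph() here: calling it would wipe out the graph
        # that the Dataset/iterator ops already live in, so the layers below
        # would end up in a different graph than x_in.
        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])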

Upvotes: 3

David Parks

Reputation: 32111

It took me a bit to get my head around this too. You're on the right track. The entire Dataset definition is just part of the graph. I generally create it as a separate class from my Model class and pass the dataset into the Model class. I specify the Dataset class I want to load on the command line and then load that class dynamically, which keeps the Dataset and the model graph cleanly decoupled.
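
As a rough sketch of that layout (the class names, file name, and feature count below are placeholders I've made up for illustration), the dataset wrapper and the model only meet at the tensor returned by get_next():

import tensorflow as tf

NUM_FEATURES = 10  # assumed width of each comma-separated line

class TextLineData:
    """Owns the input pipeline; knows nothing about the model."""
    def __init__(self, filename, batch_size=2):
        dataset = tf.data.TextLineDataset(filename)
        dataset = dataset.map(self._parse_line, num_parallel_calls=4)
        dataset = dataset.batch(batch_size).prefetch(1)
        self.next_element = dataset.make_one_shot_iterator().get_next()

    @staticmethod
    def _parse_line(line):
        # "1.0,2.0,...": split on commas and convert to float32
        values = tf.string_to_number(tf.string_split([line], delimiter=",").values)
        values.set_shape([NUM_FEATURES])  # dense layers need a known last dimension
        return values

class Model:
    """Owns the network; takes its input tensor from whatever pipeline you pass in."""
    def __init__(self, x_in, output_size=3):
        self.layer = tf.layers.dense(x_in, output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer))

data = TextLineData("train.csv")  # hypothetical file of comma-separated floats
model = Model(x_in=data.next_element)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(model.loss))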

Notice that you can (and should) name all the tensors in the Dataset; it really helps make things easy to understand as you pass data through the various transformations you'll need.
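
For example (a small sketch along the lines of the question's TextLineDataset; the names and file are arbitrary), you can pass name= to the ops inside your map function and to get_next():

def parse_line(line):
    fields = tf.string_split([line], delimiter=",").values
    return tf.string_to_number(fields, name="parsed_features")

dataset = tf.data.TextLineDataset("train.csv")  # hypothetical file
dataset = dataset.map(parse_line).batch(2)
next_batch = dataset.make_one_shot_iterator().get_next(name="next_batch")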

You can write simple test cases that pull samples from iterator.get_next() and display them; you'll have something like sess.run(next_element_tensor), with no feed_dict, as you've correctly noted.
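
Something like this (assuming a dataset object built as in the sketch above):

# Sanity check: pull one batch straight from the pipeline and print it
next_element_tensor = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(next_element_tensor))  # no feed_dict needed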

Once you get your head around it you'll probably start liking the Dataset input pipeline. It forces you to modularize your code well, and it forces it into a structure that's easy to unit test.

Make sure you read the developer's guide; there are tons of examples there:

https://www.tensorflow.org/programmers_guide/datasets

Another thing I'll note is how easy it is to work with a train and a test dataset with this pipeline. That's important because you often perform data augmentation on the training dataset that you don't perform on the test dataset. from_string_handle lets you switch between the two and is clearly described in the guide above.
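
A minimal sketch of that pattern (using in-memory arrays as stand-in train/test data): a string handle placeholder picks which iterator feeds the graph, so the same model tensors serve both datasets.

import tensorflow as tf
import numpy as np

train_data = np.random.standard_normal([8, 10]).astype(np.float32)
test_data = np.random.standard_normal([4, 10]).astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices(train_data).shuffle(8).batch(2).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test_data).batch(2)

# The handle is the only thing fed at run time; the data itself stays in the graph
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_batch = iterator.get_next()  # feed this tensor into your model

train_iter = train_ds.make_one_shot_iterator()
test_iter = test_ds.make_initializable_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Training step: select the training iterator
    print(sess.run(next_batch, feed_dict={handle: train_handle}))

    # Evaluation: (re)initialize the test iterator and select it
    sess.run(test_iter.initializer)
    print(sess.run(next_batch, feed_dict={handle: test_handle}))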

Upvotes: 7
