bluesummers

Reputation: 12657

Gradually building tf.Graph and executing it

I am trying to gradually build a tf.Graph based on some conditions and run it once after I have finished appending ops.

The code looks as follows:

class Model:
    def __init__(self):
        self.graph = tf.Graph()
        ... some code ...

    def build_initial_graph(self):
        with self.graph.as_default():
            X = tf.placeholder(tf.float32, shape=some_shape)
            ... some code ...

    def add_to_existing_graph(self):
        with self.graph.as_default():
            ... some code adding more ops to the graph ...

    def transform(self, data):
        with tf.Session(graph=self.graph) as session:
             y = session.run(Y, feed_dict={X: data})
        return y

Calling the methods will look something like this:

model = Model()
model.build_initial_graph()
model.add_to_existing_graph()
model.add_to_existing_graph()
result = model.transform(data)

So, two questions:

  1. Is this a legitimate way of adding operations to an existing graph? Will using the same graph object in different places keep extending it, or will it override the old one?
  2. In the transform method, X in the feed_dict is not recognized when the code is run. What is the right way to achieve this?

Upvotes: 0

Views: 72

Answers (1)

pfm

Reputation: 6338

Q1: This is of course a legitimate way to build your model, but it's more a matter of opinion. I would only suggest storing your tensors as attributes (see the answer to Q2), e.g. self.X = ....

You could have a look at this very nice post on how to structure your TensorFlow model in an object-oriented way.

Q2: The reason is simple: the variable X is not in the scope of your transform method.
If you do the following, everything will work fine:

def build_initial_graph(self):
    with self.graph.as_default():
        self.X = tf.placeholder(tf.float32, shape=some_shape)
        ... some code ...

def transform(self, data):
    with tf.Session(graph=self.graph) as session:
         return session.run(self.Y, feed_dict={self.X: data})

To be more detailed: in TensorFlow, all the Tensors or Operations you define (e.g. tf.placeholder or tf.matmul) are defined in the tf.Graph() you're working on. You might want to store them in Python variables, as you did by doing X = tf.placeholder(...), but that's not mandatory.

If you later want to access one of the Tensors you defined, you can either

  • use the Python variable (this was your attempt, except that the variable X was not in the scope of the method), or
  • retrieve it directly from the graph (you need to know its name), using the graph's get_tensor_by_name method.
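The second option can be sketched as follows. This is a minimal, self-contained example (not the asker's actual model; the placeholder shape and the reduce_sum op are made up for illustration), using the TF1-style API through tf.compat.v1 so it also runs on TensorFlow 2:

```python
import tensorflow.compat.v1 as tf  # TF1-style graph/session API
tf.disable_eager_execution()

# Build a graph, giving the ops explicit names so we can find them later.
graph = tf.Graph()
with graph.as_default():
    X = tf.placeholder(tf.float32, shape=(None, 2), name="X")
    Y = tf.reduce_sum(X, axis=1, name="Y")

# Later, even without the Python variables X and Y in scope, the tensors
# can be retrieved by name ("<op_name>:0" is the op's first output).
X_ref = graph.get_tensor_by_name("X:0")
Y_ref = graph.get_tensor_by_name("Y:0")

with tf.Session(graph=graph) as session:
    result = session.run(Y_ref, feed_dict={X_ref: [[1.0, 2.0]]})
```

Note that get_tensor_by_name requires the tensor name, not just the op name, hence the ":0" suffix; if you don't name your ops explicitly, TensorFlow auto-generates names like "Placeholder:0".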

Upvotes: 2
