I am trying to gradually build a tf.Graph based on some conditions and run it once after I have finished appending to it.
The code is as follows:
class Model:
    def __init__(self):
        self.graph = tf.Graph()
        ... some code ...

    def build_initial_graph(self):
        with self.graph.as_default():
            X = tf.placeholder(tf.float32, shape=some_shape)
            ... some code ...

    def add_to_existing_graph(self):
        with self.graph.as_default():
            ... some code adding more ops to the graph ...

    def transform(self, data):
        with tf.Session(graph=self.graph) as session:
            y = session.run(Y, feed_dict={X: data})
        return y
Calling the methods looks something like this:
model = Model()
model.build_initial_graph()
model.add_to_existing_graph()
model.add_to_existing_graph()
result = model.transform(data)
So, two questions:
X in the feed_dict is not recognized when the code is run. What would be the right way to achieve that?
Q1: This is a legitimate way to build your model, but it is largely a matter of opinion. My only suggestion is to store your tensors as attributes, e.g. self.X = ... (see the answer to Q2).
You could have a look at this very nice post on how to structure your TensorFlow model in an object-oriented way.
Q2: The reason is simple: X is a local variable of build_initial_graph, so it is not in the scope of your transform method.
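This is ordinary Python scoping rather than anything TensorFlow-specific. A minimal sketch without TensorFlow (ScopeDemo and its methods are hypothetical names used only for illustration) shows the same failure and the same fix:

```python
class ScopeDemo:
    def build(self):
        # `x` is a local name: it is gone as soon as build() returns.
        x = "placeholder"
        # `self.x` is an attribute: it lives as long as the instance does.
        self.x = x

    def use_local(self):
        # `x` is not defined in this method (nor at module level), so the
        # lookup fails with NameError, just like X in the question's transform().
        return x

    def use_attribute(self):
        # Attributes are reachable from any method through `self`.
        return self.x


demo = ScopeDemo()
demo.build()
print(demo.use_attribute())  # placeholder

try:
    demo.use_local()
    local_lookup_failed = False
except NameError as err:
    local_lookup_failed = True
    print("NameError:", err)
```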
If you do the following, everything works fine:
def build_initial_graph(self):
    with self.graph.as_default():
        self.X = tf.placeholder(tf.float32, shape=some_shape)
        ... some code ...

def transform(self, data):
    with tf.Session(graph=self.graph) as session:
        return session.run(self.Y, feed_dict={self.X: data})
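A self-contained, runnable version of the fix above, as a sketch only. It assumes the tf.compat.v1 API (available in TF 1.13+ and TF 2.x), replaces some_shape with a hypothetical [None, 3], and uses a made-up op (y * 2 + 1) as the body of add_to_existing_graph. The key point is that both X and the current output Y are kept as attributes, and add_to_existing_graph reassigns self.Y so transform always runs the most recently appended op:

```python
import numpy as np
import tensorflow as tf

# Graph-mode API that exists in both TF 1.x (>= 1.13) and TF 2.x.
tf1 = tf.compat.v1


class Model:
    def __init__(self):
        self.graph = tf.Graph()

    def build_initial_graph(self):
        with self.graph.as_default():
            # Keep handles as attributes so other methods can reach them.
            # [None, 3] stands in for the question's `some_shape` (assumption).
            self.X = tf1.placeholder(tf.float32, shape=[None, 3], name="X")
            self.Y = tf.identity(self.X, name="Y0")

    def add_to_existing_graph(self):
        with self.graph.as_default():
            # Hypothetical step: append one op on top of the current output
            # and make it the new output.
            self.Y = self.Y * 2.0 + 1.0

    def transform(self, data):
        with tf1.Session(graph=self.graph) as session:
            return session.run(self.Y, feed_dict={self.X: data})


model = Model()
model.build_initial_graph()
model.add_to_existing_graph()
model.add_to_existing_graph()
result = model.transform(np.zeros((2, 3), dtype=np.float32))
print(result)  # every entry is (0 * 2 + 1) * 2 + 1 = 3.0
```

If you also need intermediate outputs later, keep them in separate attributes instead of overwriting self.Y.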
To be more detailed: in TensorFlow, every Tensor or Operation you define (e.g. tf.placeholder or tf.matmul) is added to the tf.Graph you are working on. You might want to keep a reference to it in a Python variable, as you did with X = tf.placeholder(...), but that is not mandatory.
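A small sketch of that point, assuming the tf.compat.v1 API: the placeholder below is never stored in a Python variable, yet its op is still part of the graph:

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # No Python variable keeps this tensor, but the op is still added to `graph`.
    tf1.placeholder(tf.float32, shape=[None, 3], name="X")

op_names = [op.name for op in graph.get_operations()]
print(op_names)  # ['X']
```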
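If you want to access one of the Tensors you defined later on, you can either keep a reference to it in a Python variable that is in scope where you need it (X was not in the scope of the method, which is why your code failed), or retrieve it from the graph by its name (for variables, e.g. with the tf.get_variable method).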
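Both retrieval routes, sketched with the tf.compat.v1 API; the names "X", "model" and "w" are hypothetical. get_tensor_by_name expects the "<op name>:<output index>" form, and get_variable needs reuse=True inside the same variable scope to return the existing variable rather than create a new one:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # Created without keeping any Python references.
    tf1.placeholder(tf.float32, shape=[None, 2], name="X")
    with tf1.variable_scope("model"):
        tf1.get_variable("w", shape=[2, 1], initializer=tf1.ones_initializer())

# Later (e.g. in another method), look both up again by name.
with graph.as_default():
    X = graph.get_tensor_by_name("X:0")  # op "X", output 0
    with tf1.variable_scope("model", reuse=True):
        w = tf1.get_variable("w")  # returns the existing variable
    Y = tf.matmul(X, w)
    # The initializer op must be created inside this graph too.
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as session:
    session.run(init)
    out = session.run(Y, feed_dict={X: np.ones((1, 2), dtype=np.float32)})
print(out)  # [[2.]]
```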