Reputation: 941
I am having trouble understanding the graph argument of tf.Session(). I tried looking it up on the TensorFlow website but couldn't understand much.
I am trying to find out the difference between tf.Session() and tf.Session(graph=some_graph_inserted_here).
Code A:

import numpy as np
import tensorflow as tf

def predict():
    with tf.name_scope("predict"):
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            loaded_graph = tf.get_default_graph()
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))
This code gives the following error when trying to load the graph at saver = tf.train.import_meta_graph("saved_models/testing.meta"):
ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used
Code B, which works:

import numpy as np
import tensorflow as tf

def predict():
    with tf.name_scope("predict"):
        loaded_graph = tf.Graph()
        with tf.Session(graph=loaded_graph) as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))
The code does not work if I replace loaded_graph = tf.Graph() with loaded_graph = tf.get_default_graph(). Why?
Full code, if it helps: https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab
Upvotes: 5
Views: 3818
Reputation: 6328
The TensorFlow Graph is an object that contains your various tf.Tensor and tf.Operation objects.
When you create tensors or variables (e.g. using tf.constant or tf.Variable) or operations (e.g. tf.matmul), they are added to the default graph (look at the graph attribute of these objects to see which graph they belong to). If you haven't specified anything, this is the graph returned by tf.get_default_graph().
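A minimal sketch of the point above, assuming TF 2.x (or TF 1.15) used through the tensorflow.compat.v1 module with eager execution turned off, so that ops are recorded into a graph as in TF 1.x. The node names "a", "b" and "c" are hypothetical.

```python
import tensorflow.compat.v1 as tf

# Assumption: v1 compatibility API; graph mode is switched on so that ops are
# recorded into a graph instead of being executed eagerly.
tf.disable_eager_execution()

# No graph is specified, so all three nodes are added to the default graph.
a = tf.constant(2.0, name="a")
b = tf.constant(3.0, name="b")
c = tf.multiply(a, b, name="c")

default_graph = tf.get_default_graph()
print(c.graph is default_graph)     # True: the tensor knows its graph
print(c.op.graph is default_graph)  # True: so does the op that produced it
print(sorted(op.name for op in default_graph.get_operations()))  # ['a', 'b', 'c']
```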
But you can also work with multiple graphs by using a context manager:

g = tf.Graph()
with g.as_default():
    # [your code]: tensors and operations created here are added to g
    ...

If you create several graphs in your code, you then need to pass the graph you want to run as the graph argument of tf.Session, so that TensorFlow knows which one to run.
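A short sketch of two graphs side by side, again assuming the tensorflow.compat.v1 API with eager execution disabled; the graphs g1/g2 and the node names are hypothetical. It shows that the same node name can exist in two different graphs, and that a session can only run nodes from the graph it was given.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Two independent graphs; the same node names can be reused in each.
g1 = tf.Graph()
with g1.as_default():
    x1 = tf.constant(1.0, name="x")
    y1 = tf.add(x1, 10.0, name="y")

g2 = tf.Graph()
with g2.as_default():
    x2 = tf.constant(2.0, name="x")
    y2 = tf.add(x2, 20.0, name="y")

# The `graph` argument decides which graph each session executes.
with tf.Session(graph=g1) as sess1:
    r1 = sess1.run(y1)

wrong_graph_error = None
with tf.Session(graph=g2) as sess2:
    r2 = sess2.run(y2)
    try:
        sess2.run(y1)  # y1 belongs to g1, not to g2
    except ValueError as err:
        wrong_graph_error = err

print(r1, r2)  # 11.0 22.0
```

Because each graph keeps its own namespace, both y1 and y2 are named "y:0" without any clash.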
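In Code A, you import the meta graph into the default graph, which already contains the nodes of your model, while in Code B, you import it into a new, empty graph.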
This piece of code makes Code A work (I reset the default graph to a fresh one, and I removed the predict name_scope).
import numpy as np
import tensorflow as tf

def predict():
    # Replace the (already populated) default graph with a fresh, empty one
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph("saved_models/testing.meta")
        saver.restore(sess, "saved_models/testing")
        loaded_graph = tf.get_default_graph()
        output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
        _x = loaded_graph.get_tensor_by_name('x:0')
        print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))
Upvotes: 5
Reputation: 27042
When you create a Session, you attach it to one graph, whose operations the session then executes on the available devices.
If no graph is specified, the Session constructor uses the default one (which you can get with tf.get_default_graph()).
Your code A doesn't work because the current session already has a graph, and that graph already contains the exact nodes you're trying to import.
Your code B works because you're passing a new, empty graph (created with tf.Graph()) to the Session: when you import the graph definition, there is no collision between the nodes already in the current session's graph (there are none, because the graph is empty) and the ones you're importing.
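To make the difference between code A and code B concrete, here is a small sketch, assuming the tensorflow.compat.v1 API with eager execution disabled. The variable named "kernel" is a hypothetical stand-in for the model that the full script builds in the default graph before predict() runs.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical stand-in for the model built earlier in the default graph.
kernel = tf.Variable(1.0, name="kernel")

with tf.Session() as sess:
    # No `graph=` argument: the session uses the current default graph,
    # which already holds the nodes created above.
    uses_default = sess.graph is tf.get_default_graph()
    default_op_count = len(sess.graph.get_operations())

fresh = tf.Graph()
with tf.Session(graph=fresh) as sess:
    # A graph created with tf.Graph() starts out empty, so nothing in it can
    # clash with the nodes you import.
    uses_fresh = sess.graph is fresh
    fresh_op_count = len(sess.graph.get_operations())

print(uses_default, default_op_count > 0)  # True True
print(uses_fresh, fresh_op_count)          # True 0
```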
Upvotes: 1
Reputation: 3159
In TensorFlow, you construct graphs. TensorFlow creates a default (sorry for the tautology) graph, which you can access using tf.get_default_graph(). By default, any new Session object uses this default graph.
In your case, you already have a graph (the default one), and you also saved exactly this graph into the meta file. Then you try to recover this graph using tf.train.import_meta_graph(). However, since your session uses the default graph and you are trying to import an identical graph into it, you get an error: the import would duplicate existing node names, which is forbidden.
When you explicitly create a new graph object by calling tf.Graph() and create a Session object with this graph (instead of the default one), everything is fine, since the nodes are created in another graph.
Upvotes: 2
Reputation: 17191
The function tf.train.import_meta_graph("saved_models/testing.meta") adds all the nodes from the meta file to the current default graph. In the first code, the current graph is the default graph, which already has these ops defined, hence the error. In the second case, entering with tf.Session(graph=loaded_graph) as sess: also makes loaded_graph the default graph inside that block, so the nodes are loaded into the new, empty graph and it works fine.
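A minimal save-and-restore round trip in the spirit of code B, assuming the tensorflow.compat.v1 API with eager execution disabled. The tiny model (placeholder "x", variable "w", op "output") and the temporary checkpoint prefix "testing" are hypothetical, chosen only to mirror the question.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_path = os.path.join(tempfile.mkdtemp(), "testing")  # hypothetical location

# 1) Build and save a tiny model in its own graph.
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 1], name="x")
    w = tf.Variable([[2.0]], name="w")
    out = tf.matmul(x, w, name="output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)  # also writes ckpt_path + ".meta"

# 2) Restore into a brand-new graph. Entering the session's `with` block makes
#    loaded_graph the default graph, so import_meta_graph adds the nodes there.
loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    loader = tf.train.import_meta_graph(ckpt_path + ".meta")
    loader.restore(sess, ckpt_path)
    x_ = loaded_graph.get_tensor_by_name("x:0")
    out_ = loaded_graph.get_tensor_by_name("output:0")
    result = sess.run(out_, feed_dict={x_: np.array([3.0]).reshape([-1, 1])})

print(result)  # [[6.]]
```

Neither step touches the global default graph, so calling this twice in one script would not produce the duplicate-name problem from code A.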
Upvotes: 1