Bosen

Reputation: 941

What does graph argument in tf.Session() do?

I am having trouble understanding the graph argument in tf.Session(). I tried looking it up on the TensorFlow website but couldn't understand much.

I am trying to find out the difference between tf.Session() and tf.Session(graph=some_graph_inserted_here).

Question Context

Code A (Not Working):

def predict():
    with tf.name_scope("predict"):
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            loaded_graph = tf.get_default_graph()
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

This code gives the following error when trying to load the graph at saver = tf.train.import_meta_graph("saved_models/testing.meta"): ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

Code B (Working):

def predict():
    with tf.name_scope("predict"):
        loaded_graph = tf.Graph()
        with tf.Session(graph=loaded_graph) as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

The code does not work if I replace loaded_graph = tf.Graph() with loaded_graph = tf.get_default_graph(). Why?

Full Code if it helps: (https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)

Upvotes: 5

Views: 3818

Answers (4)

pfm

Reputation: 6328

The TensorFlow Graph is an object which contains your various tf.Tensor and tf.Operation objects.

When you create these tensors (e.g. using tf.Variable or tf.constant) or operations (e.g. tf.matmul), they are added to the default graph (look at the graph member of these objects to see which graph they belong to). If you haven't specified anything else, this is the graph you get when calling tf.get_default_graph.
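For instance (a minimal sketch, assuming TensorFlow 1.x), ops created without any explicit graph land in the default graph, which you can check through their graph member:

import tensorflow as tf

a = tf.constant(1.0, name="a")             # added to the default graph
b = tf.Variable(2.0, name="b")             # also added to the default graph
print(a.graph is tf.get_default_graph())   # True
print(b.graph is tf.get_default_graph())   # True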

But you can also work with multiple graphs using a context manager:

g = tf.Graph()
with g.as_default():
    [your code]

If you have created several graphs in your code, you then need to pass the graph you want to run as an argument to tf.Session to tell TensorFlow which one to use.
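For example (a hedged sketch with made-up tensor names):

import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
    tf.constant(1.0, name="c")

g2 = tf.Graph()
with g2.as_default():
    tf.constant(2.0, name="c")

# Explicitly run the second graph in this session.
with tf.Session(graph=g2) as sess:
    print(sess.run(g2.get_tensor_by_name("c:0")))  # prints 2.0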

In Code A, you

  • work with the default graph,
  • try to import the meta graph into it (which fails because it already contains some of the nodes) and,
  • would restore the model into it,

while in Code B, you

  • create a fresh new graph,
  • import the meta graph into it (which succeeds because it's an empty graph) and
  • restore it.

Useful link:

tf.Graph API

Edit:

This piece of code makes Code A work (it resets the default graph to a fresh one and removes the predict name_scope).

def predict():
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph("saved_models/testing.meta")
        saver.restore(sess, "saved_models/testing")
        loaded_graph = tf.get_default_graph()
        output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
        _x = loaded_graph.get_tensor_by_name('x:0')
        print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

Upvotes: 5

nessuno

Reputation: 27042

When you create a Session, you're placing a graph onto a specified device.

If no graph is specified, the Session constructor tries to build a graph using the default one (that you can get using tf.get_default_graph).

Your Code A doesn't work because the current session already has a graph, and that graph already contains the exact same nodes you're trying to import.

Your Code B works because you're placing a new, empty graph (created with tf.Graph()) into the Session: when you import the graph definition, there's no collision between the existing nodes in the current session (there are none, because the graph is empty) and the ones you're importing.
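A small sketch of this behaviour (assuming TensorFlow 1.x):

import tensorflow as tf

# With no graph argument, the session is backed by the default graph...
with tf.Session() as sess:
    print(sess.graph is tf.get_default_graph())  # True

# ...while passing a graph explicitly places that graph into the session.
g = tf.Graph()
with tf.Session(graph=g) as sess:
    print(sess.graph is g)                       # True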

Upvotes: 1

Dmytro Danevskyi

Reputation: 3159

In TensorFlow, you construct graphs. By default, TensorFlow creates a default (sorry for the tautology) graph, which you can access using tf.get_default_graph(). By default, any new Session object uses this default graph.

In your case, you already have a graph (the default one), and you saved exactly this graph into the meta file. Then you try to recover this graph using tf.train.import_meta_graph(). However, since your session uses the default graph, and you are trying to recover an identical one, you get an error: this operation would duplicate the nodes, which is forbidden.

When you explicitly create a new graph object by calling tf.Graph() and create a Session object using this graph (not the default one), everything is fine, since the nodes are created in another graph.
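A self-contained sketch of the collision (paths and variable names are illustrative, assuming TensorFlow 1.x):

import tensorflow as tf

x = tf.Variable(0.0, name="x")                   # lives in the default graph
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "/tmp/dup_demo")            # also writes /tmp/dup_demo.meta

# The default graph already contains "x" and the saver ops, so re-importing
# the identical meta graph into it tries to duplicate those nodes (as in Code A).
try:
    tf.train.import_meta_graph("/tmp/dup_demo.meta")
except Exception as e:
    print("import into the default graph failed:", e)

# A fresh graph contains none of those nodes, so the same import succeeds.
with tf.Graph().as_default():
    tf.train.import_meta_graph("/tmp/dup_demo.meta")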

Upvotes: 2

Vijay Mariappan

Reputation: 17191

The function tf.train.import_meta_graph("saved_models/testing.meta") adds all the nodes from the meta file to the current graph. In the first code, the current graph is the default graph, which already has those ops defined, hence the error. In the second case, you are loading the nodes into a new graph, so it works fine.
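For instance (a sketch that assumes the meta file from the question exists on disk):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    print(len(g.get_operations()))                            # 0 -- the new graph is empty
    tf.train.import_meta_graph("saved_models/testing.meta")   # adds every node from the meta file
    print(len(g.get_operations()) > 0)                        # True -- the nodes now live in g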

Upvotes: 1
