xpilot

Reputation: 1019

Reloading tensorflow model

I have two separate TensorFlow processes: one trains a model and writes out graph_defs with tensorflow.python.client.graph_util.convert_variables_to_constants, and the other reads the graph_def with tensorflow.import_graph_def. I would like the second process to periodically reload the graph_def as the first process updates it. Unfortunately, every time I re-read the graph_def the old one still appears to be used, even if I close the current session and create a new one. I have also tried wrapping the import_graph_def call in with sess.graph.as_default():, to no avail. Here is my current graph_def loading code:

if self.sess is not None:
    self.sess.close()
self.sess = tf.Session()

graph_def = tf.GraphDef()
with open(self.graph_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
with self.sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')

Upvotes: 1

Views: 2357

Answers (1)

mrry

Reputation: 126154

The problem here is that, when you create a tf.Session with no arguments, it uses the current default graph. Assuming you don't create a tf.Graph anywhere else in your code, you get the global default graph that is created when the process starts, and this is shared between all of the sessions. As a result, with self.sess.graph.as_default(): has no effect: the session's graph is already the default graph, so every import goes into that same shared graph.
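A quick way to see the sharing described above is to compare the graph objects directly. This is only a minimal sketch: it assumes the TF 1.x-style API is reachable through tensorflow.compat.v1, and the variable names are made up for illustration.

```python
# Assumption: the TF 1.x-style API is available via the compat module.
import tensorflow.compat.v1 as tf

# Two sessions created with no arguments.
sess_a = tf.Session()
sess_b = tf.Session()

# Both wrap the same process-wide default graph.
shares_default_graph = (
    sess_a.graph is sess_b.graph
    and sess_a.graph is tf.get_default_graph()
)

# A session that is handed its own tf.Graph does not use the default graph.
sess_c = tf.Session(graph=tf.Graph())
has_private_graph = sess_c.graph is not tf.get_default_graph()

for sess in (sess_a, sess_b, sess_c):
    sess.close()
```

Closing sess_a does not discard the default graph, which is why closing and recreating the session alone does not give you an empty graph to import into.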

It's hard to recommend a new structure from the snippet you showed in the question (in particular, I have no idea how you created the previous graph, or what the class structure is), but one possibility would be to replace self.sess = tf.Session() with the following:

self.sess = tf.Session(graph=tf.Graph())  # Creates a new graph for the session.

Now the with self.sess.graph.as_default(): will use the graph that was created for the session, and your program should have the intended effect.

A somewhat preferable (to me, at least) alternative would be to build the graph explicitly:

with tf.Graph().as_default() as imported_graph:
    tf.import_graph_def(graph_def, ...)

sess = tf.Session(graph=imported_graph)
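Putting the explicit-graph approach together with the file-reading code from the question, a reload step might look roughly like the sketch below. The helper name load_graph_def_into_new_session is hypothetical, the tensorflow.compat.v1 import is an assumption about how the TF 1.x-style API is reached, and name='' is taken from the question's code rather than from anything required by this approach.

```python
# Assumption: the TF 1.x-style API is available via the compat module.
import tensorflow.compat.v1 as tf


def load_graph_def_into_new_session(graph_path):
    """Hypothetical helper: read a serialized GraphDef and return a
    session bound to a brand-new graph that contains only that import."""
    graph_def = tf.GraphDef()
    with open(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Build a fresh graph on every call, so nodes from an earlier
    # import can never be picked up by mistake.
    with tf.Graph().as_default() as imported_graph:
        tf.import_graph_def(graph_def, name='')

    return tf.Session(graph=imported_graph)
```

In the question's class, this would replace the body of the loading code: close the old self.sess if there is one, then assign self.sess = load_graph_def_into_new_session(self.graph_path). Tensors then have to be looked up again on the new self.sess.graph, since handles obtained from the old graph do not belong to the new one.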

Upvotes: 5
