Reputation:
I have realized that there is some funky stuff going on with the way TensorFlow manages graphs.
Since building (and rebuilding) models is so tedious, I decided to wrap my custom model in a class so I could easily re-instantiate it elsewhere.
When I was training and testing the code (in its original place) it worked fine; however, in the code where I loaded the graph's variables I got all sorts of weird errors: variable redefinitions and everything else. This (based on my last question about a similar issue) hinted that everything was being built twice.
After doing a TON of tracing, it came down to the way I was using the loaded code. It was being used from within a class with a structure like so:
class MyModelUser(object):
    def forecast(self):
        # ... build the model in the same way as in the training code
        # load the model checkpoint
        # call the "predict" function on the model
        # manipulate the prediction and return it
And then, in some code that uses MyModelUser, I had:
def test_the_model(self):
    model_user = MyModelUser()
    print(model_user.forecast())  # 1
    print(model_user.forecast())  # 2
and I (obviously) expected to see two forecasts when this was called. Instead, the first call worked as expected, but the second threw a TON of variable-reuse ValueErrors. An example of one of these was:
ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope?
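(For context, this error is easy to reproduce on its own: building the same variable scope twice into one default graph is enough. The scope name below just mirrors the one from my model.)

import tensorflow as tf

with tf.variable_scope("weight_def"):
    w1 = tf.get_variable("weights", shape=[10])

# Building the "same" model a second time revisits the same scope and raises:
# ValueError: Variable weight_def/weights already exists, disallowed. ...
with tf.variable_scope("weight_def"):
    w2 = tf.get_variable("weights", shape=[10])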
I managed to quell the errors by adding a series of try/except blocks that used get_variable to create the variable and, on exception, called reuse_variables on the scope and then called get_variable again with nothing but the name. This brought on a new set of nasty errors, one of which was:
tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files
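Roughly, that workaround looked like the sketch below (the helper name is mine, not from the real code):

import tensorflow as tf

def get_or_reuse_variable(name, shape):
    # Try to create the variable; on a reuse ValueError, mark the
    # current scope as reusing and fetch the existing variable by name.
    try:
        return tf.get_variable(name, shape=shape)
    except ValueError:
        tf.get_variable_scope().reuse_variables()
        return tf.get_variable(name)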
On a whim I said, "what if I move the model-building code to __init__ so it's only built once?"
My new model user:
class MyModelUser(object):
    def __init__(self):
        # ... build the model in the same way as in the training code
        # load the model checkpoint

    def forecast(self):
        # call the "predict" function on the model
        # manipulate the prediction and return it
and now:
def test_the_model(self):
    model_user = MyModelUser()
    print(model_user.forecast())  # 1
    print(model_user.forecast())  # 2
This works as expected, printing two forecasts with no errors, which leads me to believe I can also get rid of the variable-reuse workaround.
My question is this:
Why did this fix it? In theory, the graph should be re-instantiated on every call to the original forecast method, so it shouldn't be creating more than one graph. Does TensorFlow persist the graph even after the function completes? Is this why moving the creation code to __init__ worked? This has left me hopelessly confused.
Upvotes: 2
Views: 1081
Reputation: 126154
By default, TensorFlow uses a single global tf.Graph instance that is created when you first call a TensorFlow API. If you do not create a tf.Graph explicitly, all operations, tensors, and variables will be created in that default instance. This means that each call in your code to model_user.forecast() will be adding operations to the same global graph, which is somewhat wasteful.
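You can see the accumulation directly by counting operations in the default graph (a minimal illustration):

import tensorflow as tf

print(len(tf.get_default_graph().get_operations()))  # 0: the graph starts empty
tf.constant(1.0)
print(len(tf.get_default_graph().get_operations()))  # 1
tf.constant(1.0)  # "rebuilding" adds new ops rather than replacing old ones
print(len(tf.get_default_graph().get_operations()))  # 2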
There are (at least) two possible courses of action here:
The ideal action would be to restructure your code so that MyModelUser.__init__() constructs an entire tf.Graph with all of the operations needed to perform forecasting, and MyModelUser.forecast() simply performs sess.run() calls on the existing graph. Ideally, you would also create only a single tf.Session, because TensorFlow caches information about the graph in the session, making execution more efficient. (A minimal sketch of this structure appears after the example below.)
The less invasive (but probably less efficient) change would be to create a new tf.Graph for every call to MyModelUser.forecast(). It's unclear from the question how much state is created in the MyModelUser.__init__() method, but you could do something like the following to put the two calls in different graphs:
def test_the_model(self):
    with tf.Graph().as_default():  # Create a local graph for the first model
        model_user_1 = MyModelUser()
        print(model_user_1.forecast())
    with tf.Graph().as_default():  # Create another local graph
        model_user_2 = MyModelUser()
        print(model_user_2.forecast())
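For concreteness, here is a minimal sketch of the first (ideal) structure. The placeholder shapes, variable name, checkpoint path, and the batch argument to forecast() are all hypothetical, since the question elides the real model-building code:

import tensorflow as tf

class MyModelUser(object):
    def __init__(self):
        # Build the graph exactly once, and create a single session
        # that is reused for every forecast.
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Hypothetical stand-ins for the real model-building code.
            self.inputs = tf.placeholder(tf.float32, shape=[None, 10])
            weights = tf.get_variable("weights", shape=[10, 1])
            self.prediction = tf.matmul(self.inputs, weights)
            self.saver = tf.train.Saver()
        self.sess = tf.Session(graph=self.graph)
        self.saver.restore(self.sess, "/path/to/checkpoint")  # hypothetical path

    def forecast(self, batch):
        # No graph construction here: just run the existing operations.
        return self.sess.run(self.prediction, feed_dict={self.inputs: batch})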
Upvotes: 4
Reputation: 17131
TF has a default graph that new operations etc. get added to. When you call your function twice, you add the same things twice to the same graph. So either build the graph once and evaluate it multiple times (as you have done; this is also the "normal" approach), or, if you want to change things, use reset_default_graph (https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graph) to reset the graph and get a fresh state.
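For example, a minimal sketch of the reset approach:

import tensorflow as tf

v = tf.get_variable("weights", shape=[10])
tf.reset_default_graph()  # discard the old default graph entirely
# The default graph is now empty, so building the "same" variable again
# does not trigger a reuse ValueError.
v = tf.get_variable("weights", shape=[10])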
Upvotes: 0