Reputation: 27
I wrote a small Python program that creates and trains a model (an LSTM-VAE) with different parameters. Unfortunately, it contains several @tf.function-decorated functions: the first model trains fine, but the next one fails with ValueError: tf.function-decorated function tried to create variables on non-first call.
@tf.function
def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_pred = model.decode(z)
    cross_ent = tf.keras.metrics.MSE(y_pred=x_pred, y_true=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
def compute_apply_gradients(model, x, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
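The failure shows up when these decorated functions are reused across separately constructed models, roughly like the sweep below (this loop is only an illustration of the pattern; LSTMVAE, parameter_grid, and train_batches are placeholders, not my actual code):

for params in parameter_grid:               # placeholder: sweep over model configurations
    model = LSTMVAE(**params)               # fresh model per configuration
    optimizer = tf.keras.optimizers.Adam()  # fresh optimizer per configuration
    for x in train_batches:                 # placeholder: batched training data
        compute_apply_gradients(model, x, optimizer)
    # The first configuration trains fine; the second one fails with:
    # ValueError: tf.function-decorated function tried to create variables on non-first call.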
To be able to rerun, I always have to restart the kernel.
Is there a way to delete the underlying TensorFlow graph manually? I think that would solve this, since (at the moment) I am only interested in the results.
Best regards and thanks for your help!
Upvotes: 0
Views: 325
Reputation: 604
Without seeing your full training code, I can only assume that your tf.Session is running on the default graph. Hence, to clear the graph you could use
tf.reset_default_graph()
In TensorFlow Core r2.0 this is available as
tf.compat.v1.reset_default_graph()
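A minimal sketch of that call under TensorFlow 2.x (note that eager code and tf.function do not rely on a single default graph, so this only clears v1-style graph state):

import tensorflow as tf

# Clears the global (v1-style) default graph; under TF 2.x the v1 name
# tf.reset_default_graph() is only reachable through the compat module.
tf.compat.v1.reset_default_graph()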
Upvotes: 1
Reputation: 27
OK, I had another look. Attaching the functions that are marked with
@tf.function
directly to the respective model class (as methods) is a cleaner way to solve this without relying on compat functions.
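A minimal sketch of that restructuring, reusing the helpers from the question (encode, decode, reparameterize, log_normal_pdf); the class name LSTMVAE is only a placeholder. When @tf.function decorates a method, each model instance gets its own traced function, so building a fresh model for the next parameter set no longer trips the "non-first call" variable check:

import tensorflow as tf

class LSTMVAE(tf.keras.Model):
    # ... encoder/decoder layers defined in __init__ ...

    @tf.function
    def compute_loss(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_pred = self.decode(z)
        cross_ent = tf.keras.metrics.MSE(y_pred=x_pred, y_true=x)
        logpx_z = -tf.reduce_sum(cross_ent, axis=[1])
        logpz = log_normal_pdf(z, 0., 0.)
        logqz_x = log_normal_pdf(z, mean, logvar)
        return -tf.reduce_mean(logpx_z + logpz - logqz_x)

    @tf.function
    def compute_apply_gradients(self, x, optimizer):
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x)
        gradients = tape.gradient(loss, self.trainable_variables)
        optimizer.apply_gradients(zip(gradients, self.trainable_variables))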
Upvotes: 0