Karsten Schreiblehner

Reputation: 27

Delete Tensorflow model without kernel restart

I wrote a small Python program to create and train a model (LSTM-VAE) with different parameters. Unfortunately, because several functions are decorated with @tf.function, I can train the first model, but training the next one fails with ValueError: tf.function-decorated function tried to create variables on non-first call.

import tensorflow as tf

# model.encode / model.reparameterize / model.decode and log_normal_pdf
# are defined elsewhere in the program (not shown here).

@tf.function
def compute_loss(model, x):

    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_pred = model.decode(z)

    cross_ent = tf.keras.metrics.MSE(y_pred=x_pred, y_true=x)

    logpx_z = -tf.reduce_sum(cross_ent, axis=[1])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)

    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

# This single module-level tf.function is reused for every model/optimizer
# pair trained in the same kernel session.
@tf.function
def compute_apply_gradients(model, x, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

To be able to run it again, I always have to restart the kernel.

Is there a way to delete the underlying TensorFlow graph manually? I think this would solve the problem, since (at the moment) I am only interested in the results.

Best regards and thanks for your help!

Upvotes: 0

Views: 325

Answers (2)

velociraptor11

Reputation: 604

Without seeing your full training code, I can only assume that your tf.Session is running on the default graph. In that case, you can clear the graph with

tf.reset_default_graph()

In TensorFlow Core r2.0, the equivalent is

tf.compat.v1.reset_default_graph()
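A minimal sketch of what resetting the default graph does, assuming TF1-style graph execution (eager mode switched off via the compat API). The helper build_and_run is hypothetical: each call starts from an empty default graph, so creating a variable with the same name "w" a second time does not clash with the previous run.

```python
import tensorflow as tf

# Assumption: TF1-style graph mode, where ops and variables are added
# to the implicit default graph.
tf.compat.v1.disable_eager_execution()


def build_and_run(value):
    """Hypothetical helper: build a tiny graph from scratch and evaluate it."""
    # Discard everything that earlier runs added to the default graph.
    tf.compat.v1.reset_default_graph()
    w = tf.compat.v1.get_variable("w", initializer=tf.constant(value))
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        return w.name, sess.run(w)


# Two consecutive "training runs" in the same Python process.
first = build_and_run(1.0)
second = build_and_run(2.0)
```

Without the reset_default_graph() call, the second get_variable("w", ...) would raise a ValueError because a variable named w already exists in the default graph.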

Upvotes: 1

Karsten Schreiblehner

Reputation: 27

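OK, I had another look. Attaching the functions decorated with

@tf.function

directly to the respective model class as methods is a cleaner way to solve this, without relying on compat functions. Each new model instance then gets its own traced function, so its variables are created on that instance's first call.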
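A minimal sketch of that approach, assuming TF 2.x with tf.keras. TinyModel, its single Dense layer and the squared-error loss are hypothetical stand-ins for the LSTM-VAE and its ELBO loss; only the structure (the @tf.function-decorated training functions living on the model class) follows the answer above.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for the LSTM-VAE: one Dense layer."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.dense(x)

    @tf.function
    def compute_loss(self, x):
        # Hypothetical loss: squared error against the first input feature.
        return tf.reduce_mean(tf.square(self(x) - x[:, :1]))

    @tf.function
    def compute_apply_gradients(self, x, optimizer):
        # Decorated as a method, so every model instance gets its own traced
        # function; creating this model's weights and the optimizer's slot
        # variables happens on that instance's first call.
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x)
        gradients = tape.gradient(loss, self.trainable_variables)
        optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return loss


# Usage: two parameter settings, each with a fresh model and optimizer,
# trained in the same Python process without restarting the kernel.
x = tf.random.normal([8, 4])
losses = []
for learning_rate in (1e-2, 1e-3):
    model = TinyModel()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    for _ in range(3):
        loss = model.compute_apply_gradients(x, optimizer)
    losses.append(float(loss))
```

Compared with the module-level functions in the question, nothing is shared between models here, so the second model no longer triggers variable creation inside an already-traced function.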

Upvotes: 0
