I am trying to use @tf.function(jit_compile=True) to create a TF graph, as shown below. I can't provide fully working code because it has a lot of dependencies.
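import tensorflow as tf

# Load the model once, outside the compiled function (the path is elided in the
# original; MODEL_PATH is a hypothetical placeholder).
model = tf.keras.models.load_model(MODEL_PATH)

@tf.function(jit_compile=True)
def myfunction(inputs, model):
    # ... tf.while_loop(...) logic omitted in the original ...
    out2 = model(inputs)
    # tf.gradients only works in graph mode, which tf.function provides
    output = tf.gradients(out2, inputs)
    return output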
Is it possible to use model.predict() in the code above so that I can use a large batch size? Calling model() directly only lets me predict for a fixed batch size, and I need predictions on larger batches.
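One possible workaround, sketched here as an assumption rather than taken from the post: in tf.keras, Model.predict is a high-level endpoint that manages its own tf.function and is not meant to be called inside one, so the large batch can instead be split into fixed-size chunks and the model called directly on each chunk with tf.map_fn. The factory make_chunked_input_grads, the chunk size of 32, and the one-layer demo_model are hypothetical stand-ins. The sketch uses tf.GradientTape in place of tf.gradients, requires the batch size to be a multiple of the chunk size (other sizes would need padding first), and leaves out jit_compile=True, so XLA compatibility is not assumed.

```python
import tensorflow as tf


def make_chunked_input_grads(model, chunk_size):
    """Return a tf.function computing d(model(x))/dx over a large batch,
    calling the model only on fixed-size chunks of `chunk_size` rows."""

    @tf.function
    def chunked_input_grads(inputs):
        n_features = inputs.shape[-1]  # assumes a 2-D input with a static last dim
        # [N, F] -> [N // chunk_size, chunk_size, F]; N must be divisible by chunk_size
        chunks = tf.reshape(inputs, [-1, chunk_size, n_features])

        def grad_of_chunk(x):
            # Gradient of the (summed) model output w.r.t. one fixed-size chunk
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = model(x)
            return tape.gradient(y, x)

        # tf.map_fn loops over the leading axis, one chunk per iteration
        grads = tf.map_fn(grad_of_chunk, chunks)
        # Stitch the per-chunk gradients back into the original [N, F] shape
        return tf.reshape(grads, tf.shape(inputs))

    return chunked_input_grads


# Hypothetical stand-in for the real model, which the question does not show.
demo_model = tf.keras.Sequential(
    [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
)
grad_fn = make_chunked_input_grads(demo_model, chunk_size=32)
g = grad_fn(tf.random.normal([128, 4]))
print(g.shape)
```

Capturing the model in a closure, rather than passing it as a tf.function argument, keeps tracing tied to a single model object.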