newbie

Reputation: 463

TensorFlow: finding gradients of model output with respect to input for a large batch size

I am trying to use @tf.function(jit_compile=True) to create a TF graph, as shown below. I can't provide working code because it has a lot of dependencies, so this is pseudocode.


@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model(...)  # path omitted
  while ...:  # pseudocode for the tf.while_loop; condition omitted
    out2 = model(inputs)
    output = tf.gradients(out2, inputs)
  return output
  

Is it possible to use model.predict() in the above code so that I can use a large batch size? Calling model() directly only lets me predict on a fixed batch size, and I want to run prediction on larger batches.
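To make the goal concrete, here is a minimal sketch of one possible way to get input gradients for a large batch. It rests on the fact that model.predict() returns NumPy arrays, so gradients cannot flow through it; instead, the sketch splits the inputs into chunks, computes the gradient for each chunk with tf.GradientTape, and concatenates the results. The small Sequential model, the names input_gradients and batched_input_gradients, and the chunk_size value are all hypothetical stand-ins, not the code from the question.

```python
import tensorflow as tf

# Hypothetical stand-in for the real model, which would be loaded once,
# outside the traced function, with tf.keras.models.load_model(...).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(1),
])

@tf.function  # jit_compile=True could be added if every op in the model supports XLA
def input_gradients(x):
    """Gradient of the summed model output with respect to the inputs x."""
    with tf.GradientTape() as tape:
        tape.watch(x)  # x is a plain tensor, so it must be watched explicitly
        out = model(x, training=False)
    return tape.gradient(out, x)

def batched_input_gradients(inputs, chunk_size=256):
    """Process a large batch in chunks and concatenate the per-chunk gradients."""
    chunks = tf.data.Dataset.from_tensor_slices(inputs).batch(chunk_size)
    return tf.concat([input_gradients(chunk) for chunk in chunks], axis=0)

inputs = tf.random.normal((1000, 4))
grads = batched_input_gradients(inputs)
print(grads.shape)  # (1000, 4)
```

Because each sample's output depends only on that sample's input, the chunked result should match the gradient computed on the whole batch at once. Note that the last chunk has a different shape, which causes tf.function to retrace once.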

Upvotes: 1

Views: 82

Answers (0)