Jensun Ravichandran

Reputation: 1087

What is the TensorFlow/Keras equivalent of PyTorch's `no_grad` function?

When writing machine learning models, I often need to compute metrics or run additional forward passes in callbacks for visualization. In PyTorch, I wrap these in `torch.no_grad()`, which prevents gradients from being computed, so these operations do not influence the optimization.

  1. How does this mechanism work in TensorFlow/Keras?
  2. Keras models are callable, so something like `model(x)` is possible. But it is also possible to call `model.predict(x)`, which also seems to invoke the call. Is there a difference between the two?

Upvotes: 6

Views: 6213

Answers (1)

T1Berger

Reputation: 505

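The TensorFlow equivalent would be `tf.stop_gradient`, which lets a value be computed in the forward pass while blocking gradients from flowing back through it. Note also that in TensorFlow 2.x, gradients are only computed for operations recorded inside a `tf.GradientTape`, so code that runs outside a tape never contributes to gradients in the first place.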
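A minimal sketch of what `tf.stop_gradient` does, assuming TensorFlow 2.x with eager execution; the variable `w` and the toy loss are made up for illustration. The stopped branch is still evaluated, but it is treated as a constant when differentiating, which makes `tf.stop_gradient` closer in spirit to PyTorch's per-tensor `.detach()` than to the block-level `no_grad` context.

```python
import tensorflow as tf

# Hypothetical toy variable, used only for illustration.
w = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = w * w                    # recorded: contributes dy/dw = 2w
    z = tf.stop_gradient(w * w)  # value is computed, but treated as a constant
    loss = y + z

# Only the unstopped branch contributes: d(loss)/dw = 2 * 3.0 = 6.0
grad = tape.gradient(loss, w)
print(float(grad))
```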

Also, don't forget that Keras does not compute gradients when you use `predict` (or simply call the model via `__call__` outside a `tf.GradientTape`).
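A minimal sketch, assuming TensorFlow 2.x with `tf.keras`; the tiny `Dense` model, the random data and the mean-absolute-error "metric" are hypothetical and only for illustration. Outside a tape, `model(x)` returns a tensor and `model.predict(x)` runs batched inference and returns a NumPy array; neither records anything. Inside a tape, `tape.stop_recording()` pauses recording for an extra forward pass, which is arguably the closest analogue to wrapping that block in `torch.no_grad()`. The tape is made `persistent` here only so that gradients can be queried twice.

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model and data, used only for illustration.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
x = np.random.rand(8, 4).astype("float32")
y_true = np.zeros((8, 1), dtype="float32")

# Outside a GradientTape nothing is recorded, so no gradients are computed.
out_call = model(x, training=False)     # returns a tf.Tensor
out_pred = model.predict(x, verbose=0)  # batched inference, returns a NumPy array

with tf.GradientTape(persistent=True) as tape:
    y_pred = model(x, training=True)
    loss = tf.reduce_mean(tf.square(y_true - y_pred))
    with tape.stop_recording():
        # Extra forward pass (e.g. for a metric); not recorded on the tape.
        metric = tf.reduce_mean(tf.abs(y_true - model(x, training=False)))

grads = tape.gradient(loss, model.trainable_variables)           # real gradients
metric_grads = tape.gradient(metric, model.trainable_variables)  # [None, None]
del tape  # release the persistent tape's resources
```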

Upvotes: 6
