Reputation: 1087
When writing machine learning models, I often need to compute metrics or run additional forward passes in callbacks for visualization purposes. In PyTorch, I do this with torch.no_grad(), which prevents gradients from being computed, so these operations do not influence the optimization.
In Keras, model(x) is possible. But it is also possible to say model.predict(x), which also seems to invoke the call. Is there a difference between the two?
Upvotes: 6
Views: 6213
Reputation: 505
The TensorFlow equivalent would be tf.stop_gradient.
Also, keep in mind that Keras does not compute gradients when you use predict (or when you just call the model via __call__), because TensorFlow 2 only records gradients inside a tf.GradientTape.
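To make this concrete, here is a minimal sketch assuming TensorFlow 2 with eager execution. The toy model, the input x, and the variable w are hypothetical and only serve as illustration. It shows both forward-pass styles running outside any tape, and tf.stop_gradient blocking one branch inside a tape.

```python
import tensorflow as tf

# Hypothetical toy model and input, for illustration only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1),
])
x = tf.ones((2, 3))

# No tf.GradientTape is active here, so nothing is recorded and these
# forward passes cannot influence the optimization.
y_call = model(x, training=False)     # returns a tf.Tensor
y_pred = model.predict(x, verbose=0)  # returns a NumPy array, processed in batches

# Inside a tape, tf.stop_gradient treats its argument as a constant.
w = tf.Variable(2.0)
with tf.GradientTape() as tape:
    loss = w * w + tf.stop_gradient(w * w)

# Only the first term contributes to the gradient: d(w * w)/dw = 2 * w.
grad = tape.gradient(loss, w)
```

In this sketch, model(x) hands back a tensor that can be used directly in further TensorFlow code, while model.predict(x) returns a NumPy array, which tends to suit metric computation or plotting in a callback.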
Upvotes: 6