Why is the gradient of tf.where(x > 1, tf.math.log(x), 0) nan when x is 0.0, but not when it's -1 or 1?
Minimal example:
import tensorflow as tf

x = tf.constant([-1, 0, 1], tf.float32)
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.where(x > 1, tf.math.log(x), 0)
print(y)
dy_dx = g.gradient(y, x)
print(dy_dx)
Output:
tf.Tensor([0. 0. 0.], shape=(3,), dtype=float32)
tf.Tensor([-0. nan 0.], shape=(3,), dtype=float32)
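For comparison, here is a sketch of the commonly suggested "safe input" variant. It rests on the assumption that the nan comes from the unselected log branch: tf.where passes a zero upstream gradient to that branch, but the log gradient still multiplies it by 1/x, which gives 0 * inf = nan at x = 0, while at x = -1 it gives 0 * (-1) = -0. The names mask and safe_x, the placeholder value 1.0, and the extra input 2.0 are illustrative choices of mine, not part of the example above.

```python
import tensorflow as tf

# Same inputs as above, plus 2.0 (added here) so the log branch is actually selected once.
x = tf.constant([-1.0, 0.0, 1.0, 2.0], tf.float32)

with tf.GradientTape() as g:
    g.watch(x)
    mask = x > 1
    # Feed log a harmless value (1.0) wherever its result will be discarded,
    # so both log(x) and its derivative 1/x stay finite for every element.
    safe_x = tf.where(mask, x, tf.ones_like(x))
    y = tf.where(mask, tf.math.log(safe_x), tf.zeros_like(x))

dy_dx = g.gradient(y, x)
print(y)
print(dy_dx)
```

If the assumption holds, the gradient here should contain no nan: the masked-out elements get a gradient of 0 and the element at x = 2 gets 1/2.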