Suppose I have a 2-D tensor of shape (batch_size, loss_dim), and I want the sum over the loss dimension for each data sample, which can be done with tf.reduce_sum(tensor, axis=-1).
However, what if there are NaN values in my tensor and I want to simply ignore those NaNs when calculating the sum? Does anyone know how to do that?
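For reference, here is a minimal sketch of the per-sample sum and of how a single NaN propagates through it. It assumes TensorFlow 2.x with eager execution, and the 2×3 batch `losses` is a made-up stand-in for the real tensor:

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch: batch_size=2, loss_dim=3, one NaN in the second sample.
losses = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, np.nan, 6.0]])

# Summing over the last (loss) axis gives one value per sample, shape (batch_size,).
per_sample_sum = tf.reduce_sum(losses, axis=-1)

# The first entry is 6.0, but the second is NaN: one NaN poisons its whole row.
print(per_sample_sum.numpy())
```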
PS. I know that we can use tf.boolean_mask to filter out the NaNs, but if I simply do tensor = tf.boolean_mask(tensor, tf.logical_not(tf.is_nan(tensor))), the output is squashed into a single dimension, which is not what I want.
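The flattening happens because each sample can keep a different number of non-NaN entries, so the surviving values cannot fit in a dense (batch_size, k) tensor. The sketch below contrasts the two behaviours, assuming TensorFlow 2.x and the same made-up 2×3 batch. Note that `tf.ragged.boolean_mask` is a swapped-in alternative, not something the question or the answer uses. It keeps one variable-length row per sample instead of flattening:

```python
import numpy as np
import tensorflow as tf

# Hypothetical 2x3 batch standing in for the real losses.
losses = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, np.nan, 6.0]])
keep = tf.logical_not(tf.math.is_nan(losses))

# With a mask of the same rank as the tensor, tf.boolean_mask returns a 1-D
# tensor holding the 5 kept values; the row boundaries are lost.
flat = tf.boolean_mask(losses, keep)

# tf.ragged.boolean_mask keeps one variable-length row per sample
# (3 values in the first row, 2 in the second), so a per-sample sum
# along the last axis is still possible.
ragged = tf.ragged.boolean_mask(losses, keep)
row_sums = tf.reduce_sum(ragged, axis=-1)  # dense tensor: 6.0 and 10.0
```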
Thank you so much!
You can use tf.where() to replace the NaN values in tensor with zero while retaining the original shape:
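import tensorflow as tf

tensor = ...  # your (batch_size, loss_dim) float tensor
# Replace all NaN values with 0.0 (tf.math.is_nan is the TF 2.x name of tf.is_nan).
tensor_without_nans = tf.where(tf.math.is_nan(tensor), tf.zeros_like(tensor), tensor)
# Sum over the loss dimension; the result has shape (batch_size,).
sum_ignoring_nans = tf.reduce_sum(tensor_without_nans, axis=-1)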
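Here is a worked run of the same tf.where approach, assuming TensorFlow 2.x and a made-up 2×3 batch. Zero-filling is only neutral for a sum. The last two lines are a hedged extension not in the original answer. They show that a NaN-ignoring mean has to divide by the number of valid entries, because dividing by loss_dim would count each NaN as a 0:

```python
import numpy as np
import tensorflow as tf

# Hypothetical batch: the second sample contains one NaN.
losses = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, np.nan, 6.0]])
is_nan = tf.math.is_nan(losses)

# Same technique as above: swap NaNs for zeros, keeping the 2-D shape.
losses_without_nans = tf.where(is_nan, tf.zeros_like(losses), losses)
sum_ignoring_nans = tf.reduce_sum(losses_without_nans, axis=-1)  # 6.0 and 10.0

# Extension (assumption): average over valid entries only. A row that is
# entirely NaN has a count of 0, so its mean comes out as NaN (0/0).
valid_counts = tf.reduce_sum(tf.cast(tf.logical_not(is_nan), losses.dtype), axis=-1)
mean_ignoring_nans = sum_ignoring_nans / valid_counts  # 2.0 and 5.0
```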