I want to update part of a tensor based on a condition. I know that TensorFlow tensors are immutable, so creating a new tensor would be fine for me.
I tried the `tf.tensor_scatter_nd_update` method, but I couldn't make it work.
This is the NumPy code I want to replicate in TensorFlow:
```python
import numpy as np

a = np.random.random((1, 3))
b = np.array([[0, 1, 0]])
c = np.zeros_like(a)
mask = b == 1
# Write log(a) into c only where b == 1; other entries stay 0.
c[mask] = np.log(a[mask])
```
In TensorFlow, tensors are immutable, so we do not update them in place. Instead, we build new tensors from existing ones, as in functional languages. `tf.where` selects element-wise between two tensors according to a boolean mask:
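```python
import tensorflow as tf

a = tf.random.uniform(shape=(1, 3))
b = tf.constant([[0, 1, 0]], dtype=tf.int32)
c = tf.zeros_like(a)
mask = b == 1  # element-wise boolean mask

# Take log(a) where mask is True and keep c everywhere else.
c_updated = tf.where(mask, tf.math.log(a), c)
print(c_updated.numpy())
# e.g. [[ 0.       , -4.175911,  0.       ]]  (values vary because a is random)
```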
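Since the question mentions `tf.tensor_scatter_nd_update`, here is a minimal sketch of how that route could look, assuming TensorFlow 2.x in eager mode. `masked_log_scatter` is a hypothetical helper name used only for illustration. The idea is to turn the mask into integer coordinates with `tf.where(mask)`, gather only the selected values with `tf.gather_nd`, and scatter their logs into a zero tensor.

```python
import numpy as np
import tensorflow as tf

def masked_log_scatter(a, b):
    """Hypothetical helper: return zeros_like(a) with log(a) written where b == 1."""
    mask = tf.equal(b, 1)
    # Coordinates of every True entry, shape [num_true, rank(a)], dtype int64.
    indices = tf.where(mask)
    # Gather only the selected values, so log is never evaluated elsewhere.
    updates = tf.math.log(tf.gather_nd(a, indices))
    # Build a new tensor: zeros everywhere, updates at the given coordinates.
    return tf.tensor_scatter_nd_update(tf.zeros_like(a), indices, updates)

a = tf.random.uniform(shape=(1, 3))
b = tf.constant([[0, 1, 0]], dtype=tf.int32)
c_updated = masked_log_scatter(a, b)
print(c_updated.numpy())
```

One possible reason to prefer this over `tf.where`: `tf.where(mask, tf.math.log(a), c)` computes `log` on every element, so a zero in an unmasked position yields `-inf` in the discarded branch and can produce NaN gradients. Gathering first avoids that, at the cost of a few extra ops.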