The following is how it works in NumPy:
import numpy as np
vals_for_fives = [12, 18, 22, 33]
arr = np.array([5, 2, 3, 5, 5, 5])
arr[arr == 5] = vals_for_fives # It is guaranteed that length of vals_for_fives is equal to the number of fives in arr
# now the value of arr is [12, 2, 3, 18, 22, 33]
For broadcastable or constant assignment we can use where() and assign() in TensorFlow. How can we achieve the above scenario in TF?
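One way to make where() cover this case is to first expand vals_for_fives to the full length of arr, so that tf.where has a same-shape tensor to select from. A running count of matches gives each 5 its position in vals_for_fives. This is a sketch, not from the original answer; it assumes TensorFlow 2.x eager mode and at least one 5 in arr, and assign_at_mask is a hypothetical helper name, not a TensorFlow API.

```python
import tensorflow as tf

def assign_at_mask(arr, mask, vals):
    """Hypothetical helper: mimic NumPy's arr[mask] = vals (vals consumed in order).

    Assumes len(vals) equals the number of True entries in mask, and that
    mask has at least one True entry.
    """
    arr = tf.convert_to_tensor(arr)
    vals = tf.convert_to_tensor(vals, dtype=arr.dtype)
    # Running count of True entries: the k-th True position maps to vals[k - 1].
    idx = tf.cumsum(tf.cast(mask, tf.int32)) - 1
    # Clamp to 0 so positions before the first True still gather a valid index;
    # tf.where discards those gathered values anyway.
    filled = tf.gather(vals, tf.maximum(idx, 0))
    # Take the expanded values where mask is True, the original values elsewhere.
    return tf.where(mask, filled, arr)

arr = tf.constant([5, 2, 3, 5, 5, 5])
result = assign_at_mask(arr, arr == 5, [12, 18, 22, 33])
print(result.numpy().tolist())  # [12, 2, 3, 18, 22, 33]
```

Everything stays element-wise, so no scatter is needed. The clamp only matters for entries that come before the first match.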
tf.experimental.numpy.where is available in TensorFlow v2.5.
But for now you could do this. First, find the positions of the 5s:
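import numpy as np
import tensorflow as tf

arr = np.array([5, 2, 3, 5, 5, 5])
where = tf.where(arr == 5)  # int64 coordinates of each 5, shape (num_fives, 1)
where = tf.cast(where, tf.int32)  # int32 to match the int32 shape passed to scatter_nd
print(where)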
# <tf.Tensor: id=91, shape=(4, 1), dtype=int32, numpy=
# array([[0],
#        [3],
#        [4],
#        [5]])>
Then use scatter_nd to "replace" elements by index, placing the new values at those positions in an otherwise zero tensor of the same length as arr:
tf.scatter_nd(where, tf.constant([12, 18, 22, 33]), tf.constant([6]))
# <tf.Tensor: id=94, shape=(6,), dtype=int32, numpy=array([12, 0, 0, 18, 22, 33])>
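A possibly simpler alternative, assuming TensorFlow 2.x, is tf.tensor_scatter_nd_update. It scatters updates into a copy of an existing tensor instead of into zeros, so the non-5 entries keep their values without a second scatter and a sum. This swaps in a different op from the one used in this answer. The sketch below is an illustration, not the answerer's method.

```python
import tensorflow as tf

arr = tf.constant([5, 2, 3, 5, 5, 5])
vals_for_fives = tf.constant([12, 18, 22, 33])

# One-argument tf.where returns the int64 coordinates of True entries,
# shape (num_fives, 1) for a 1-D input; int64 indices are accepted here.
indices = tf.where(arr == 5)

# Write the new values into a copy of arr; all other positions are left unchanged.
result = tf.tensor_scatter_nd_update(arr, indices, vals_for_fives)
print(result.numpy().tolist())  # [12, 2, 3, 18, 22, 33]
```

The original arr is not modified, because TensorFlow tensors are immutable. The op returns a new tensor.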
Do the same for the entries that are not 5 to get the complementary tensor:
tf.scatter_nd(tf.constant([[1], [2]]), tf.constant([2, 3]), tf.constant([6]))
# <tf.Tensor: id=98, shape=(6,), dtype=int32, numpy=array([0, 2, 3, 0, 0, 0])>
Then sum the two tensors. They are non-zero at disjoint positions, so the sum is:
# <tf.Tensor: id=113, shape=(6,), dtype=int32, numpy=array([12, 2, 3, 18, 22, 33])>
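The steps above hard-code the non-5 indices ([[1], [2]]) and the output shape. A minimal sketch of the same scatter-and-sum idea can instead derive both from arr. It assumes TensorFlow 2.x eager mode and that len(vals) equals the number of matches. replace_via_two_scatters is a hypothetical helper name.

```python
import tensorflow as tf

def replace_via_two_scatters(arr, target, vals):
    """Hypothetical helper: the scatter-and-sum approach without hard-coded indices."""
    arr = tf.convert_to_tensor(arr)
    vals = tf.convert_to_tensor(vals, dtype=arr.dtype)
    # tf.shape returns int32; scatter_nd needs indices and shape of the same dtype.
    shape = tf.shape(arr)
    hit = tf.cast(tf.where(tf.equal(arr, target)), tf.int32)       # positions of target
    miss = tf.cast(tf.where(tf.not_equal(arr, target)), tf.int32)  # all other positions
    # New values at the matches, zeros elsewhere.
    new_part = tf.scatter_nd(hit, vals, shape)
    # Original values at the non-matches, zeros elsewhere.
    old_part = tf.scatter_nd(miss, tf.gather_nd(arr, miss), shape)
    # The two tensors are non-zero at disjoint positions, so their sum is the result.
    return new_part + old_part

result = replace_via_two_scatters(tf.constant([5, 2, 3, 5, 5, 5]), 5, [12, 18, 22, 33])
print(result.numpy().tolist())  # [12, 2, 3, 18, 22, 33]
```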