Vijay Giri

Reputation: 57

How to do NumPy-like conditional assignment in TensorFlow

This is how it works in NumPy:

import numpy as np

vals_for_fives = [12, 18, 22, 33]
arr = np.array([5, 2, 3, 5, 5, 5])
arr[arr == 5] = vals_for_fives  # It is guaranteed that the length of vals_for_fives equals the number of fives in arr

# now the value of arr is [12, 2, 3, 18, 22, 33]

For broadcastable or constant assignment, we can use where() and assign() in TensorFlow. How can we achieve the scenario above, where the replacement values come from a shorter list, in TF?
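For contrast, here is a minimal sketch of the constant and broadcastable cases that the three-argument tf.where already handles, assuming TensorFlow 2.x with eager execution. It shows why the question's case is different: tf.where picks values position by position, so it needs a full-length replacement tensor rather than the shorter list [12, 18, 22, 33].

```python
import tensorflow as tf

arr = tf.constant([5, 2, 3, 5, 5, 5])

# Constant replacement: every 5 becomes 0. The scalar 0 is broadcast
# against arr; elsewhere the original value is kept.
constant_fill = tf.where(arr == 5, tf.constant(0), arr)
print(constant_fill.numpy())  # [0 2 3 0 0 0]

# Same-shaped replacement: values are taken element by element, so the
# replacement must already have the full length of arr.
full_length = tf.constant([12, 0, 0, 18, 22, 33])
elementwise = tf.where(arr == 5, full_length, arr)
print(elementwise.numpy())  # [12  2  3 18 22 33]
```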

Upvotes: 0

Views: 134

Answers (1)

SuperCiocia

Reputation: 1971

tf.experimental.numpy.where is available in TensorFlow v2.5.

But for now you could do this:

First find the positions of the 5's:

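import numpy as np
import tensorflow as tf

arr = np.array([5, 2, 3, 5, 5, 5])
where = tf.where(arr == 5)          # int64 indices of the 5's, shape (4, 1)
where = tf.cast(where, tf.int32)    # int32, to match the int32 shape passed to scatter_nd
print(where)
# <tf.Tensor: id=91, shape=(4, 1), dtype=int32, numpy=
# array([[0],
#        [3],
#        [4],
#        [5]])>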

Then use scatter_nd to "replace" elements by index:

tf.scatter_nd(where, tf.constant([12, 18, 22, 33]), tf.constant([6]))
# <tf.Tensor: shape=(6,), dtype=int32, numpy=array([12,  0,  0, 18, 22, 33])>

Do the same for the entries that are not 5 (indices 1 and 2, values 2 and 3) to build the rest of the tensor:

tf.scatter_nd(tf.constant([[1], [2]]), tf.constant([2, 3]), tf.constant([6]))
# <tf.Tensor: shape=(6,), dtype=int32, numpy=array([0, 2, 3, 0, 0, 0])>
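Rather than hardcoding the indices and values of the non-5 entries, they can be derived from the tensor itself. A minimal sketch, assuming TensorFlow 2.x eager execution; the names arr_t, keep_idx and keep_vals are illustrative only. Everything is kept in int32 so the indices match the dtype of tf.shape.

```python
import tensorflow as tf

arr_t = tf.constant([5, 2, 3, 5, 5, 5])  # int32 by default

# Indices of the entries to keep (not equal to 5), cast to int32, shape (2, 1).
keep_idx = tf.cast(tf.where(arr_t != 5), tf.int32)

# The values at those indices: [2, 3].
keep_vals = tf.gather_nd(arr_t, keep_idx)

# Scatter them into a zero tensor with the original shape (tf.shape is int32).
rest = tf.scatter_nd(keep_idx, keep_vals, tf.shape(arr_t))
print(rest.numpy())  # [0 2 3 0 0 0]
```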

Then add the two tensors. Because the two sets of indices do not overlap, the sum is the desired result:

<tf.Tensor: shape=(6,), dtype=int32, numpy=array([12,  2,  3, 18, 22, 33])>
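Putting the steps together, here is a sketch of the whole approach as one function, assuming TensorFlow 2.x, a 1-D input, and that the number of replacement values equals the number of matches (as the question guarantees). The function name assign_where_equal is hypothetical. The last lines show a different op, tf.tensor_scatter_nd_update, which writes the updates into a copy of the original tensor, so the separate scatter for the non-matching entries is not needed.

```python
import tensorflow as tf

def assign_where_equal(arr, target, new_vals):
    """Hypothetical helper: emulate NumPy's arr[arr == target] = new_vals
    for a 1-D tensor, filling matches in order from new_vals."""
    arr = tf.convert_to_tensor(arr)
    new_vals = tf.convert_to_tensor(new_vals, dtype=arr.dtype)
    mask = tf.equal(arr, target)

    # Indices of matches and non-matches, as int32 to match tf.shape.
    hit_idx = tf.cast(tf.where(mask), tf.int32)
    keep_idx = tf.cast(tf.where(tf.logical_not(mask)), tf.int32)
    shape = tf.shape(arr)

    # Scatter each group into zeros; the groups never overlap, so adding works.
    replaced = tf.scatter_nd(hit_idx, new_vals, shape)
    kept = tf.scatter_nd(keep_idx, tf.gather_nd(arr, keep_idx), shape)
    return replaced + kept

result = assign_where_equal(tf.constant([5, 2, 3, 5, 5, 5]), 5, [12, 18, 22, 33])
print(result.numpy())  # [12  2  3 18 22 33]

# Alternative op: update a copy of arr in place of the two-scatter sum.
arr = tf.constant([5, 2, 3, 5, 5, 5])
one_step = tf.tensor_scatter_nd_update(
    arr, tf.where(arr == 5), tf.constant([12, 18, 22, 33]))
print(one_step.numpy())  # [12  2  3 18 22 33]
```

The two-scatter version mirrors the steps above; tensor_scatter_nd_update is shorter and avoids relying on the zero-fill of scatter_nd.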

Upvotes: 2
