Roby

Reputation: 144

How to implement NumPy's where-based indexing in TensorFlow?

I have the following operation, which uses numpy.where:

    import numpy as np

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    mat[np.where(index > 0)] = 100  # overwrite every element where index is positive
    print(mat)
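For reference, np.where called with a single argument returns a tuple of index arrays (one per dimension), which is why it can be used directly as a fancy index; indexing with the boolean mask itself selects the same cells. A minimal sketch (the names rows, cols, via_where and via_mask are illustrative only):

```python
import numpy as np

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# One-argument np.where returns a tuple of index arrays: (row indices, column indices).
rows, cols = np.where(index > 0)
print(rows, cols)  # [0 1 2] [0 1 2]

# Indexing with that tuple and indexing with the boolean mask select the same cells.
via_where = mat.copy()
via_where[np.where(index > 0)] = 100
via_mask = mat.copy()
via_mask[index > 0] = 100
print(np.array_equal(via_where, via_mask))  # True
print(via_mask)
```

This tuple-of-arrays form is the key difference from TensorFlow, whose tf.where returns a single 2-D tensor of coordinates.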

How can I implement the equivalent in TensorFlow? This is what I tried:

    import numpy as np
    import tensorflow as tf

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tf_mat = tf.constant(mat)
    tf_index = tf.constant(index)
    indi = tf.where(tf_index > 0)
    tf_mat[indi] = -1  # <== not allowed: a tf.Tensor does not support item assignment

Upvotes: 4

Views: 6125

Answers (2)

Kaihong Zhang

Reputation: 419

You can get the indices with tf.where. From there you can either evaluate the indices directly (e.g. with sess.run), use tf.gather (or tf.gather_nd for multi-dimensional indices) to collect data from the original array, or use tf.scatter_update to update the original data (tf.scatter_nd_update for multi-dimensional updates):

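    import tensorflow as tf  # TensorFlow 1.x API

    mat = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.int32)
    index = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    idx = tf.where(index > 0)                   # int64 coordinates, shape [N, 2]
    updates = tf.fill([tf.shape(idx)[0]], 100)  # one new value per coordinate (100 as an example)
    update_op = tf.scatter_nd_update(mat, idx, updates)  # run this op in a session to apply it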

Note that the updates must have the same first-dimension size as idx, i.e. one value per row of idx.
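The steps described in this answer can also be sketched in TensorFlow 2.x eager mode. This is an assumption-laden variant rather than the answer's own code: it uses tf.gather_nd to read the selected elements and tf.tensor_scatter_nd_update, which returns a new tensor instead of updating a variable in place. The replacement value 100 is taken from the question's NumPy example.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

mat = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.int32)
index = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# One-argument tf.where returns an int64 tensor of shape [N, 2]:
# one (row, col) coordinate per True element, not a tuple as np.where does.
idx = tf.where(index > 0)

# Read the selected elements (the counterpart of mat[np.where(index > 0)]).
selected = tf.gather_nd(mat, idx)

# One update value per coordinate, so updates has the same first dimension as idx.
updates = tf.fill([tf.shape(idx)[0]], 100)

# Write the values into a copy of mat; mat itself is left unchanged.
new_mat = tf.tensor_scatter_nd_update(mat, idx, updates)

print(idx.numpy().tolist())       # [[0, 0], [1, 1], [2, 2]]
print(selected.numpy().tolist())  # [1, 5, 9]
print(new_mat.numpy())
```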

See the TensorFlow Python API guides: https://www.tensorflow.org/api_guides/python

Upvotes: 2

javidcf

Reputation: 59681

Assuming that what you want is to create a new tensor with some elements replaced, rather than to update a variable, you could do something like this:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tf_mat = tf.constant(mat)
    tf_index = tf.constant(index)
    # Take -1 where tf_index > 0, otherwise keep the original element of tf_mat
    tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
    with tf.Session() as sess:
        print(sess.run(tf_mat))

Output:

    [[-1  2  3]
     [ 4 -1  6]
     [ 7  8 -1]]
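In TensorFlow 2.x, which executes eagerly and has no top-level tf.Session, the same three-argument tf.where approach can be sketched as follows. This assumes TF 2.x and uses tf.fill to build the replacement tensor (in place of -tf.ones_like); the result is a new tensor and tf_mat is not modified.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, no tf.Session)

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)

# Three-argument tf.where selects elementwise: from the second argument where the
# condition is True, otherwise from the third. All three have the same shape here.
replacement = tf.fill(tf.shape(tf_mat), -1)
result = tf.where(tf_index > 0, replacement, tf_mat)

print(result.numpy())
```

Unlike the scatter-based approach, this needs no explicit indices: the boolean condition itself decides which elements are replaced.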

Upvotes: 7
