QNNet

Reputation: 83

Finding indices at multiple locations in a Tensor at runtime and replacing them with 0

I wish to assign 0 to multiple locations in a Tensor of shape (n, m) at runtime.

I computed the indices using tf.where in TensorFlow, and called tf.scatter_nd_update to assign tf.constant(0) at the newly found locations.

oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0))
updates = tf.placeholder(tf.float64, [None, None])
oscvec_empty = tf.placeholder(tf.int64, [None])
tf.cond(tf.not_equal(tf.size(oscvec), 0), tf.scatter_nd_update(save_parms, oscvec, tf.constant(0, dtype=tf.float64)),
        tf.scatter_nd_update(save_parms, oscvec_empty, updates))

I expect tf.where to return an empty tensor when the condition is not satisfied, and a non-empty tensor of indices into save_parms otherwise. I created an empty oscvec_empty placeholder to handle the case where tf.where returns an empty tensor. However, this does not work, as seen from the following error, which is raised when TensorFlow's if-else construct, tf.cond, is used to update the save_parms tensor via tf.scatter_nd_update:

ValueError: Shape must be at least rank 1 but is rank 0 for 'ScatterNdUpdate' (op: 'ScatterNdUpdate') with input shapes: [55], [?,1], [].

Is there a way to replace values at multiple locations in the save_parms tensor when oscvec is non-empty, and leave it unchanged when oscvec is empty? The sgn tensors are the result of applying the sign function to save_parms based on a given criterion.
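A minimal sketch of the scatter-based route, assuming TensorFlow 2.x with eager execution (the question uses the 1.x graph API). Two things in the snippet above cause trouble: tf.cond expects callables (for example lambdas) for its branches rather than already-built ops, and the updates argument of a scatter must have shape indices.shape[:-1] + params.shape[indices.shape[-1]:], so a scalar tf.constant(0) is rejected with the rank-0 error shown above. Building a zero vector with one entry per found index avoids both problems: when tf.where finds nothing, the indices have shape [0, rank], the updates have shape [0], and the scatter leaves the tensor unchanged, so no tf.cond is needed. The helper name zero_where and the sample values are hypothetical, chosen only for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def zero_where(params, mask):
    """Hypothetical helper: return a copy of `params` with entries where `mask` is True set to 0."""
    # Indices of the True entries, shape [num_true, rank], dtype int64.
    indices = tf.where(mask)
    # One zero per index: shape [num_true], i.e. indices.shape[:-1].
    updates = tf.zeros(tf.shape(indices)[:1], dtype=params.dtype)
    # When num_true == 0 the scatter is a no-op, so no tf.cond is required.
    return tf.tensor_scatter_nd_update(params, indices, updates)


# Illustrative inputs (not from the question).
sgn1 = tf.sign(tf.constant([[0.5, -1.2, -0.3], [0.4, 1.1, -0.1]], dtype=tf.float64))
sgn2 = tf.sign(tf.constant([[1.5, 0.3, -0.7], [1.5, -1.5, 0.05]], dtype=tf.float64))
save_parms = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=tf.float64)

mask = tf.math.logical_and(sgn2 > 0, sgn1 < 0)
result = zero_where(save_parms, mask)
# Entries (0, 1) and (1, 2) satisfy the condition and are zeroed.
print(result.numpy())
```

Because the empty case falls out of the shapes naturally, the same call works whether or not any index is found.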

Upvotes: 1

Views: 168

Answers (1)

giser_yugang

Reputation: 6176

You can use the three-argument form of tf.where() instead of the more complex approach in the question.

import tensorflow as tf  # TensorFlow 1.x API (tf.random_normal, tf.Session)

vec1 = tf.constant([[ 0.05734377, 0.80147606, -1.2730557 ], [ 0.42826906, 1.1943488 , -0.10129673]])
vec2 = tf.constant([[ 1.5461133 , -0.38455755, -0.79792875], [ 1.5374309 , -1.5657802 , 0.05546811]])

sgn1 = tf.sign(vec1)
sgn2 = tf.sign(vec2)
save_parms = tf.random_normal(shape=sgn1.shape)
# Element-wise select: 0 where the condition holds, the original value elsewhere
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0), tf.zeros_like(save_parms), save_parms)

with tf.Session() as sess:
    save_parms_val, oscvec_val = sess.run([save_parms, oscvec])
    print(save_parms_val)
    print(oscvec_val)

Output (save_parms is random, so the values differ between runs):

[[ 0.75645643 -0.646291   -1.2194813 ]
 [ 1.5204562  -1.0625905   2.9939709 ]]
[[ 0.75645643 -0.646291   -1.2194813 ]
 [ 1.5204562  -1.0625905   0.        ]]
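The answer's code targets the TensorFlow 1.x API (tf.random_normal, tf.Session). A minimal sketch of the same tf.where masking, assuming TensorFlow 2.x with eager execution, so no session is needed. Fixed save_parms values replace the random ones so the result is reproducible; those values are illustrative only.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

vec1 = tf.constant([[0.05734377, 0.80147606, -1.2730557], [0.42826906, 1.1943488, -0.10129673]])
vec2 = tf.constant([[1.5461133, -0.38455755, -0.79792875], [1.5374309, -1.5657802, 0.05546811]])

sgn1 = tf.sign(vec1)
sgn2 = tf.sign(vec2)
# Fixed values instead of tf.random.normal so the output is reproducible.
save_parms = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

mask = tf.math.logical_and(sgn2 > 0, sgn1 < 0)
# Element-wise select: 0 where mask is True, the original value elsewhere.
result = tf.where(mask, tf.zeros_like(save_parms), save_parms)
print(result.numpy())
```

Only entry (1, 2) satisfies sgn2 > 0 and sgn1 < 0, matching the zeroed position in the output above.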

Upvotes: 2
