I want to assign 0 to multiple locations in a tensor of size (n, m) at runtime. I computed the indices with tf.where in TensorFlow and called tf.scatter_nd_update to assign tf.constant(0) at the newly found locations.
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0))
updates = tf.placeholder(tf.float64, [None, None])
oscvec_empty = tf.placeholder(tf.int64, [None])
tf.cond(tf.not_equal(tf.size(oscvec), 0), tf.scatter_nd_update(save_parms, oscvec, tf.constant(0, dtype=tf.float64)),
        tf.scatter_nd_update(save_parms, oscvec_empty, updates))
I expect tf.where to return an empty tensor when the condition is not satisfied, and a non-empty tensor of indices into save_parms at some point. I therefore created an empty placeholder, oscvec_empty, to handle the case where tf.where returns an empty tensor. But this does not seem to work, as the following error shows. It is raised when the TensorFlow if-else construct tf.cond is used to update the save_parms parameter tensor via the tf.scatter_nd_update function:
ValueError: Shape must be at least rank 1 but is rank 0 for 'ScatterNdUpdate' (op: 'ScatterNdUpdate') with input shapes: [55], [?,1], [].
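The error comes from the shape rule of scatter_nd_update: for indices of shape [d_0, ..., d_{Q-2}, K], updates must have shape indices.shape[:-1] + ref.shape[K:]. With ref of shape [55] and indices of shape [?, 1] from tf.where, updates therefore has to be a rank-1 tensor with one value per index row, while tf.constant(0) is rank 0. The helper below is hypothetical (scatter_nd_updates_shape is not a TensorFlow function); it only evaluates that rule in plain Python. Passing an updates tensor of that shape, for example zeros with one entry per row of oscvec, should satisfy it. Separately, tf.cond expects callables such as lambda: ... for its two branches, not tensors that have already been built.

```python
def scatter_nd_updates_shape(ref_shape, indices_shape):
    """Shape `updates` must have for scatter_nd_update(ref, indices, updates).

    Hypothetical helper for illustration. indices has shape
    [d_0, ..., d_{Q-2}, K]; the last dimension K is how many leading
    dimensions of ref each index row addresses.
    """
    *batch, k = indices_shape
    if k > len(ref_shape):
        raise ValueError("index depth K cannot exceed the rank of ref")
    # Rule: updates.shape == indices.shape[:-1] + ref.shape[K:]
    return list(batch) + list(ref_shape[k:])


# Shapes from the error message: ref is [55], indices from tf.where is [?, 1].
# None stands in for the unknown number of matches.
print(scatter_nd_updates_shape([55], [None, 1]))    # rank 1, not a scalar
print(scatter_nd_updates_shape([2, 3], [None, 2]))  # full (row, col) indices
print(scatter_nd_updates_shape([2, 3], [None, 1]))  # row indices: whole rows
```

Because the required shape is [None] rather than [], a scalar update can never match it, which is exactly what "must be at least rank 1 but is rank 0" reports.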
Is there a way to replace values at multiple locations in the save_parms tensor when oscvec is non-empty, and to leave it unchanged when oscvec is empty? The sgn tensors correspond to the result of the sign function applied to save_parms based on a given criterion.
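The index-based idea can be sketched in NumPy, assuming np.argwhere as a stand-in for one-argument tf.where (both return one row of indices per match, and zero rows when nothing matches). Because the updates array is built with one value per index row, an empty index set produces an empty update and leaves the array unchanged, so no separate branch for the empty case is needed. zero_where is a hypothetical helper for illustration; it is a NumPy analogue, not TensorFlow code, and does not call scatter_nd_update.

```python
import numpy as np


def zero_where(params, mask):
    """Return a copy of params with zeros wherever mask is True.

    Hypothetical helper. np.argwhere plays the role of one-argument tf.where:
    it returns a (num_matches, ndim) array of indices.
    """
    out = params.copy()
    idx = np.argwhere(mask)  # shape (num_matches, params.ndim)
    # One zero per index row, mirroring the updates shape scatter_nd_update expects.
    updates = np.zeros(len(idx), dtype=params.dtype)
    out[tuple(idx.T)] = updates  # with zero rows, this assigns nothing
    return out


params = np.array([[1.0, -2.0], [3.0, -4.0]])
print(zero_where(params, params < 0))   # negative entries replaced by 0
print(zero_where(params, params > 10))  # no matches: values unchanged
```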
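You can use the three-argument form of tf.where() instead of the more complex approach in the question. It selects elementwise between tf.zeros_like(save_parms) and save_parms, so there is no index list and no empty case to handle.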
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.random_normal)

vec1 = tf.constant([[ 0.05734377, 0.80147606, -1.2730557 ], [ 0.42826906, 1.1943488 , -0.10129673]])
vec2 = tf.constant([[ 1.5461133 , -0.38455755, -0.79792875], [ 1.5374309 , -1.5657802 , 0.05546811]])
sgn1 = tf.sign(vec1)
sgn2 = tf.sign(vec2)
save_parms = tf.random_normal(shape=sgn1.shape)

# Elementwise select: 0 where sgn2 > 0 and sgn1 < 0, otherwise keep save_parms.
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0), tf.zeros_like(save_parms), save_parms)

with tf.Session() as sess:
    save_parms_val, oscvec_val = sess.run([save_parms, oscvec])
    print(save_parms_val)
    print(oscvec_val)
[[ 0.75645643 -0.646291 -1.2194813 ]
[ 1.5204562 -1.0625905 2.9939709 ]]
[[ 0.75645643 -0.646291 -1.2194813 ]
[ 1.5204562 -1.0625905 0. ]]
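Without a TensorFlow 1.x session at hand, the selection above can be reproduced with NumPy, whose np.where(cond, x, y) has the same elementwise-select semantics as three-argument tf.where. This sketch substitutes NumPy for TensorFlow and fixes save_parms to the random values printed above; it shows that only the entry at row 1, column 2, where sgn2 > 0 and sgn1 < 0, is replaced by 0.

```python
import numpy as np

# Inputs from the answer; save_parms is fixed to the printed random values.
vec1 = np.array([[0.05734377, 0.80147606, -1.2730557],
                 [0.42826906, 1.1943488, -0.10129673]])
vec2 = np.array([[1.5461133, -0.38455755, -0.79792875],
                 [1.5374309, -1.5657802, 0.05546811]])
save_parms = np.array([[0.75645643, -0.646291, -1.2194813],
                       [1.5204562, -1.0625905, 2.9939709]])

sgn1, sgn2 = np.sign(vec1), np.sign(vec2)
mask = np.logical_and(sgn2 > 0, sgn1 < 0)  # True only at row 1, column 2

# Three-argument where: take 0 where mask is True, otherwise keep save_parms.
oscvec = np.where(mask, np.zeros_like(save_parms), save_parms)
print(oscvec)
```

Since the condition is evaluated for every element, the result always has the shape of save_parms, whether zero, one, or many positions match.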