Aseem

Reputation: 6787

Nested while loops with scatter update in Tensorflow

I have a variable v1 = [[0,0],[0,0]] and a tensor t1 = [[-1,0],[1,1]].

I want the output op = [[1,0],[0,2]].

Logic: if a value in t1 is -1, ignore it. Otherwise, use that value as a column index into the same row of v1 and add 1 to that element.

Python equivalent:

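v1 = [[0, 0], [0, 0]]
t1 = [[-1, 0], [1, 1]]

for row in range(len(t1)):
    for col in range(len(t1[row])):
        t1_val = t1[row][col]
        if t1_val != -1:  # -1 means "skip this entry"
            v1[row][t1_val] += 1  # t1 value is a column index into the same row of v1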

I looked through a lot of questions on while loops and scatter updates, but couldn't figure out how to solve this problem.

Thanks

Upvotes: 0

Views: 119

Answers (1)

giser_yugang

Reputation: 6176

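You can use tf.map_fn to process v1 and t1 row by row: for each row, drop the -1 entries of t1, count how often each remaining index occurs with tf.math.bincount, and add those counts to the matching row of v1. Note that this returns a new tensor rather than updating v1 in place. The code uses the TensorFlow 1.x session API: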

import tensorflow as tf

v1 = tf.Variable([[0,0],[0,0]], dtype=tf.int32)
t1 = tf.constant([[-1,0],[1,1]], dtype=tf.int32)

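def add_counts(x):
    row_v, row_t = x
    # Keep only the entries of this t1 row that are not -1.
    kept = tf.gather_nd(row_t, tf.where(tf.not_equal(row_t, -1)))
    # Count how often each index occurs and add the counts to the v1 row.
    return row_v + tf.math.bincount(kept, minlength=row_v.shape[0])

# Apply the function to each (v1 row, t1 row) pair.
result = tf.map_fn(add_counts, [v1, t1], dtype=tf.int32)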

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))

# Output:
[[1 0]
 [0 2]]
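
To make the per-row logic easier to follow, here is a rough plain-Python sketch of what the tf.map_fn call computes. It is only an illustration, not TensorFlow code: the helper names bincount and add_index_counts are made up for this example, and it assumes every non-ignored value in t1 is a valid column index for its row of v1.

```python
def bincount(values, minlength):
    """Count occurrences of each non-negative integer (hypothetical helper,
    mimicking what tf.math.bincount does for a 1-D input)."""
    counts = [0] * max(minlength, max(values, default=-1) + 1)
    for v in values:
        counts[v] += 1
    return counts


def add_index_counts(v1, t1, ignore=-1):
    """Return a new matrix: each row of v1 plus the index counts of the
    matching row of t1. Entries equal to `ignore` are skipped."""
    result = []
    for v_row, t_row in zip(v1, t1):
        # Same role as tf.not_equal + tf.where + tf.gather_nd: drop ignored entries.
        kept = [idx for idx in t_row if idx != ignore]
        # Same role as tf.math.bincount(..., minlength=row length).
        counts = bincount(kept, minlength=len(v_row))
        # Assumes all indices are < len(v_row), so counts and v_row align.
        result.append([v + c for v, c in zip(v_row, counts)])
    return result


v1 = [[0, 0], [0, 0]]
t1 = [[-1, 0], [1, 1]]
print(add_index_counts(v1, t1))  # [[1, 0], [0, 2]]
```

Counting indices per row and adding the counts gives the same result as incrementing one element at a time, including when an index repeats within a row (as 1 does in the second row of t1).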

Upvotes: 1
