I have a variable `v1 = [[0,0],[0,0]]` and a tensor `t1 = [[-1,0],[1,1]]`, and I want the output `op = [[1,0],[0,2]]`.
Logic: if a value in `t1` is -1, ignore it. Otherwise, use that `t1` value as a column index into the same row of `v1` and add 1 to that `v1` element.
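Python equivalent:

```python
v1 = [[0, 0], [0, 0]]
t1 = [[-1, 0], [1, 1]]

for row in range(len(t1)):
    for col in range(len(t1[row])):
        t1_val = t1[row][col]
        if t1_val != -1:  # -1 means "skip this entry"
            v1[row][t1_val] += 1  # t1_val is a column index into the same row of v1

print(v1)  # [[1, 0], [0, 2]]
```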
I looked through a lot of questions on while loops and scatter updates but couldn't figure out how to solve this problem.
Thanks
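For comparison, the loop above is a scatter-add: every valid entry of `t1` adds 1 at position (row, `t1` value) of `v1`, and repeated positions must accumulate (row 1 targets column 1 twice). A minimal NumPy sketch of that pattern, which is not part of the question or the answer, uses `np.add.at` because plain fancy-index `+=` counts a repeated (row, column) target only once. The helper name `scatter_count` and its `ignore` parameter are hypothetical. In TensorFlow 2.x the analogous op is `tf.tensor_scatter_nd_add`, though the sketch does not use it.

```python
import numpy as np

def scatter_count(v, t, ignore=-1):
    """Hypothetical helper: return a copy of v in which v[r, t[r, c]] is
    incremented by 1 for every entry t[r, c] != ignore."""
    out = np.array(v, copy=True)
    t = np.asarray(t)
    # (row, col) positions of the entries of t that are not ignored
    rows, cols = np.nonzero(t != ignore)
    # Unbuffered scatter-add: a repeated (row, t value) target is incremented
    # once per occurrence, unlike out[rows, t[rows, cols]] += 1
    np.add.at(out, (rows, t[rows, cols]), 1)
    return out

print(scatter_count([[0, 0], [0, 0]], [[-1, 0], [1, 1]]))
# [[1 0]
#  [0 2]]
```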
You may try `tf.map_fn`:
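```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

v1 = tf.Variable([[0, 0], [0, 0]], dtype=tf.int32)
t1 = tf.constant([[-1, 0], [1, 1]], dtype=tf.int32)

def add_counts(x):
    v_row, t_row = x
    # Drop the -1 entries of this row of t1
    valid = tf.gather_nd(t_row, tf.where(tf.not_equal(t_row, -1)))
    # Count each column index (padded to the row width) and add to the v1 row
    return v_row + tf.math.bincount(valid, minlength=v_row.shape[0])

# Apply add_counts to each (row of v1, row of t1) pair
result = tf.map_fn(add_counts, [v1, t1], dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))
    # [[1 0]
    #  [0 2]]
```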
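To see what each `tf.map_fn` step computes, here is a NumPy sketch (an illustration, not part of the original answer) that applies the same per-row logic: filter out -1, count each column index with `bincount` padded to the row width via `minlength`, and add the counts to the matching row of `v1`. The name `rowwise_bincount` and its `ignore` parameter are hypothetical.

```python
import numpy as np

def rowwise_bincount(v, t, ignore=-1):
    """Hypothetical helper mirroring one map_fn step per row: drop ignored
    entries, count each column index, and add the counts to the row of v."""
    v = np.asarray(v)
    t = np.asarray(t)
    rows = []
    for v_row, t_row in zip(v, t):
        valid = t_row[t_row != ignore]
        # minlength pads the counts to the row width; an index >= the width
        # would make counts longer than v_row and the addition would fail
        counts = np.bincount(valid, minlength=v_row.shape[0])
        rows.append(v_row + counts)
    return np.stack(rows)

print(rowwise_bincount([[0, 0], [0, 0]], [[-1, 0], [1, 1]]))
# [[1 0]
#  [0 2]]
```

Like the TensorFlow answer, this returns a new array rather than modifying `v1`; in the TF1 version, `result` would have to be written back (for example with `tf.assign(v1, result)`) if `v1` itself must change.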