Yee Liu

Reputation: 1437

Index operation in TensorFlow

When doing batch labeling on some data, I use a variable to record all calculation results:

p_all = tf.Variable(tf.zeros([batch_num, batch_size]), name="probability")

During the calculation, a loop processes each batch:

for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size)}
    sess.run(p_each_batch, feed_dict=feed)

How can I copy the value of p_each_batch into p_all?

To be clearer, I want something like this:

... ...
p_all[batch_index,:] = p_each_batch
for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
    sess.run(p_all, feed_dict=feed)

How can I make this code actually work?

Upvotes: 0

Views: 450

Answers (1)

mrry

Reputation: 126184

Since p_all is a tf.Variable, you can use the tf.scatter_update() op to write each batch's result into its own row of p_all:

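import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.scatter_update)

# Assumes batch_index is a scalar int32 placeholder, as in the question's feed dict
batch_index = tf.placeholder(tf.int32, shape=[])

# Equivalent to `p_all[batch_index, :] = p_each_batch`.
# expand_dims turns the scalar index into shape [1] and the row into
# shape [1, batch_size], which is the indices/updates shape scatter_update expects.
update_op = tf.scatter_update(p_all,
                              tf.expand_dims(batch_index, 0),
                              tf.expand_dims(p_each_batch, 0))

for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
    sess.run(update_op, feed_dict=feed)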
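To make the shape rule behind the two tf.expand_dims calls concrete, here is a minimal NumPy sketch of what tf.scatter_update does to p_all. The helper `scatter_update` below is a hypothetical stand-in written for illustration, not a TensorFlow function, and the per-batch values are made up in place of real sess.run results. It assumes the documented constraint that `updates.shape` must equal `indices.shape + ref.shape[1:]`.

```python
import numpy as np


def scatter_update(ref, indices, updates):
    """Hypothetical NumPy stand-in for tf.scatter_update.

    Sets ref[indices[k]] = updates[k] in place. Requires
    updates.shape == indices.shape + ref.shape[1:].
    """
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    expected = indices.shape + ref.shape[1:]
    if updates.shape != expected:
        raise ValueError(f"updates shape {updates.shape} != expected {expected}")
    ref[indices] = updates
    return ref


batch_num, batch_size = 3, 4
p_all = np.zeros((batch_num, batch_size), dtype=np.float32)

for batch_index in range(batch_num):
    # Made-up per-batch result; in the real code this comes from the graph.
    p_each_batch = np.full(batch_size, batch_index + 0.5, dtype=np.float32)
    # Mirror tf.expand_dims(batch_index, 0) and tf.expand_dims(p_each_batch, 0):
    # the scalar index becomes shape (1,), the row becomes shape (1, batch_size).
    scatter_update(p_all,
                   np.expand_dims(batch_index, 0),
                   np.expand_dims(p_each_batch, 0))

print(p_all)
```

Passing `p_each_batch` without the extra leading dimension (shape `(batch_size,)` instead of `(1, batch_size)`) fails the shape check here, which is the same mismatch tf.scatter_update would reject.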

Upvotes: 1
