Given:
- a tensor A with shape N*M, e.g. A = tf.zeros((N, M))
- a tensor indices with shape N*k (k <= M), where row i holds column indices into row i of A
- a tensor updates with shape N*k, where row i holds the values to write into row i of A
Goal: update the elements of A whose positions appear in indices with the corresponding values from updates, i.e. set A[i, indices[i, j]] = updates[i, j].
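To pin down the intended semantics, here is a plain-Python reference on small, made-up inputs (N = 2, M = 4, k = 2). The values are illustrative only, and the function name `scatter_rows_reference` is hypothetical, not part of TensorFlow.

```python
def scatter_rows_reference(A, indices, updates):
    """Return a copy of A with A[i][indices[i][j]] = updates[i][j] (hypothetical helper)."""
    result = [list(row) for row in A]  # copy so the input A is left unchanged
    for i, (row_idx, row_upd) in enumerate(zip(indices, updates)):
        for col, val in zip(row_idx, row_upd):
            result[i][col] = val
    return result


A = [[0.0] * 4 for _ in range(2)]  # A = zeros((N, M)) with N = 2, M = 4
indices = [[0, 2], [1, 3]]  # shape (N, k)
updates = [[1.0, 2.0], [3.0, 4.0]]  # shape (N, k)
result = scatter_rows_reference(A, indices, updates)
print(result)  # [[1.0, 0.0, 2.0, 0.0], [0.0, 3.0, 0.0, 4.0]]
```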
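Current approach: use tf.scatter_nd in a loop over the rows:
import tensorflow as tf

result = []
for idx in range(N):
    # Column indices of row idx, reshaped to (k, 1) as scatter_nd expects
    index = tf.reshape(indices[idx], (-1, 1))
    row_updates = tf.convert_to_tensor(updates[idx])
    # Build row idx of the result as a length-M vector
    scatter = tf.scatter_nd(index, row_updates, shape=tf.constant([M]))
    result.append(scatter)
result = tf.stack(result, axis=0)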
This loop works, but it is only practical when N is small, since it issues a separate tf.scatter_nd call per row.
Question: how can this be vectorized to run faster?
Answer:
If the first tensor A is always made of zeros, you can do this with a single call to tf.scatter_nd:
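import tensorflow as tf

indices = ...  # shape: (n, k), column index of each update
updates = ...  # shape: (n, k), value of each update
m = ...  # number of columns of the result
s = tf.shape(indices, out_type=indices.dtype)
n = s[0]
k = s[1]
# Row number of every entry of indices: [[0, 0, ...], [1, 1, ...], ...]
idx_row = tf.tile(tf.expand_dims(tf.range(n), 1), (1, k))
# Full (row, column) coordinates, shape (n, k, 2)
idx_full = tf.stack([idx_row, indices], axis=-1)
result = tf.scatter_nd(idx_full, updates, [n, m])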
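A self-contained run of this approach on small, made-up inputs (the values are illustrative, not from the question). It uses int64 indices to show why `out_type=indices.dtype` matters: `tf.range(n)` then yields int64 row numbers that can be stacked with `indices` without a dtype mismatch.

```python
import tensorflow as tf

# Made-up example inputs: n = 2 rows, k = 2 updates per row, m = 4 columns.
indices = tf.constant([[0, 2], [1, 3]], dtype=tf.int64)
updates = tf.constant([[1.0, 2.0], [3.0, 4.0]])
m = 4  # number of columns of the result

s = tf.shape(indices, out_type=indices.dtype)  # int64, matching indices
n, k = s[0], s[1]
# Row number of every entry: [[0, 0], [1, 1]]
idx_row = tf.tile(tf.expand_dims(tf.range(n), 1), (1, k))
# (row, column) pairs, shape (n, k, 2)
idx_full = tf.stack([idx_row, indices], axis=-1)
result = tf.scatter_nd(idx_full, updates, [n, m])
# expected values: [[1, 0, 2, 0], [0, 3, 0, 4]]
```

Note that tf.scatter_nd sums updates whose indices coincide, so if a row of indices repeats a column, that entry of the result holds the sum of the corresponding updates rather than one of them.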
If the initial A contains something else, you would do essentially the same, but using tf.tensor_scatter_nd_update instead:
A = ... # shape: (n, m)
result = tf.tensor_scatter_nd_update(A, idx_full, updates)
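The same idea with a non-zero starting tensor, again on made-up values: here A is filled with -1 (an arbitrary choice) so it is easy to see which entries tf.tensor_scatter_nd_update overwrites.

```python
import tensorflow as tf

# Made-up example: A already holds non-zero values, shape (n, m) = (2, 4).
A = tf.fill([2, 4], -1.0)
indices = tf.constant([[0, 2], [1, 3]])  # int32, shape (n, k)
updates = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # shape (n, k)

s = tf.shape(indices, out_type=indices.dtype)
n, k = s[0], s[1]
idx_row = tf.tile(tf.expand_dims(tf.range(n), 1), (1, k))
idx_full = tf.stack([idx_row, indices], axis=-1)  # shape (n, k, 2)
# Overwrite only the addressed entries; every other entry keeps its value from A.
result = tf.tensor_scatter_nd_update(A, idx_full, updates)
# expected values: [[1, -1, 2, -1], [-1, 3, -1, 4]]
```

tf.tensor_scatter_nd_update returns a new tensor and leaves A itself unchanged. If a row of indices can repeat a column, which update ends up in that entry is not something to rely on, so keeping column indices unique within each row is the safer assumption.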