LearnToGrow

Reputation: 1778

How to vectorize modifying elements in a tensor

Given:

a tensor A = tf.zeros((N, M)) with shape (N, M)
a tensor indices with shape (N, k), where k <= M. Row i contains column indices into row i of A
a tensor updates with shape (N, k). Row i contains the values to write into row i of A

Goal: for every row i and every column position j, set A[i, indices[i, j]] = updates[i, j].
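To pin down the intended result, here is a small row-by-row reference in NumPy (used here only so the example runs without TensorFlow). The function name reference_update and the input values are hypothetical, chosen purely for illustration.

```python
import numpy as np

def reference_update(A, indices, updates):
    """Row-by-row reference: out[i, indices[i, j]] = updates[i, j]."""
    out = A.copy()  # leave the input untouched
    for i in range(A.shape[0]):
        for j in range(indices.shape[1]):
            out[i, indices[i, j]] = updates[i, j]
    return out

A = np.zeros((2, 4))                           # N = 2, M = 4
indices = np.array([[0, 2], [1, 3]])           # k = 2 column indices per row
updates = np.array([[5.0, 6.0], [7.0, 8.0]])   # one value per index
print(reference_update(A, indices, updates).tolist())
# [[5.0, 0.0, 6.0, 0.0], [0.0, 7.0, 0.0, 8.0]]
```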

Current approach: tf.scatter_nd in a loop over the rows:

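import tensorflow as tf

result = []
for idx in range(N):
    # Column indices of row idx, reshaped to (k, 1) as tf.scatter_nd expects
    index = tf.reshape(indices[idx], (-1, 1))
    # Scatter row idx's updates into a length-M vector of zeros
    # (the shape must have the same dtype as the indices)
    scatter = tf.scatter_nd(index, updates[idx], shape=tf.constant([M], dtype=indices.dtype))
    result.append(scatter)
result = tf.stack(result, axis=0)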

This loop works, but it is only practical when N is small, since it issues one tf.scatter_nd call per row.

Question: How can this be vectorized so it runs faster?

Upvotes: 1

Views: 44

Answers (1)

javidcf

Reputation: 59731

If A always starts as all zeros, you can do this with a single call to tf.scatter_nd by building the full (row, column) coordinates of every update:

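import tensorflow as tf

indices = ...  # integer column indices, shape: (n, k)
updates = ...  # values to write, shape: (n, k)
m = ...        # number of columns of the result (M in the question)
s = tf.shape(indices, out_type=indices.dtype)
n = s[0]
k = s[1]
# Row number of every update, [[0, 0, ...], [1, 1, ...], ...], shape (n, k)
idx_row = tf.tile(tf.expand_dims(tf.range(n), 1), (1, k))
# (row, column) coordinates of every update, shape (n, k, 2)
idx_full = tf.stack([idx_row, indices], axis=-1)
result = tf.scatter_nd(idx_full, updates, [n, m])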
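To see what idx_full contains, here is a NumPy mirror of the same steps on a tiny hypothetical input (NumPy stands in for TensorFlow so it runs without TF; np.tile and np.stack play the roles of tf.tile and tf.stack). The final assignment emulates tf.scatter_nd only under the assumption that no (row, column) pair repeats: tf.scatter_nd sums duplicate updates, whereas NumPy fancy assignment keeps just one of them.

```python
import numpy as np

indices = np.array([[0, 2], [1, 3]])          # shape (n, k)
updates = np.array([[5.0, 6.0], [7.0, 8.0]])  # shape (n, k)
n, k = indices.shape
m = 4  # number of output columns (M in the question)

# Mirrors tf.tile(tf.expand_dims(tf.range(n), 1), (1, k)): row number repeated k times
idx_row = np.tile(np.expand_dims(np.arange(n), 1), (1, k))
# Mirrors tf.stack([idx_row, indices], axis=-1): each entry is a (row, column) pair
idx_full = np.stack([idx_row, indices], axis=-1)
print(idx_full.tolist())
# [[[0, 0], [0, 2]], [[1, 1], [1, 3]]]

# Emulates tf.scatter_nd(idx_full, updates, [n, m]) for duplicate-free indices
result = np.zeros((n, m), dtype=updates.dtype)
result[idx_full[..., 0], idx_full[..., 1]] = updates
print(result.tolist())
# [[5.0, 0.0, 6.0, 0.0], [0.0, 7.0, 0.0, 8.0]]
```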

If A starts with other values, build idx_full the same way and use tf.tensor_scatter_nd_update instead:

A = ...  # shape: (n, m)
# Reuses idx_full from above; entries of A not listed in indices keep their values
result = tf.tensor_scatter_nd_update(A, idx_full, updates)
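For the non-zero case, NumPy has a direct row-wise analogue, np.put_along_axis, which writes updates[i, j] to A[i, indices[i, j]] along axis 1. This is only a NumPy illustration of the semantics tf.tensor_scatter_nd_update provides here, not a TensorFlow API, and the input values are made up.

```python
import numpy as np

A = np.arange(8, dtype=float).reshape(2, 4)       # non-zero starting values
indices = np.array([[0, 2], [1, 3]])              # column indices per row
updates = np.array([[50.0, 60.0], [70.0, 80.0]])  # values to write

result = A.copy()  # put_along_axis modifies its argument in place
np.put_along_axis(result, indices, updates, axis=1)
print(result.tolist())
# [[50.0, 1.0, 60.0, 3.0], [4.0, 70.0, 6.0, 80.0]]
```

Entries of A that are not named in indices (here 1, 3, 4 and 6) keep their original values, which is the behavior the tf.tensor_scatter_nd_update call is relied on for.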

Upvotes: 1
