user5779223

Reputation: 1490

Proper usage of `tf.scatter_nd` in tensorflow-r1.2

Given indices with shape [batch_size, sequence_len], updates with shape [batch_size, sequence_len, sampled_size], and to_shape with shape [batch_size, sequence_len, vocab_size], where vocab_size >> sampled_size, I'd like to use tf.scatter_nd to map the updates into a huge tensor of shape to_shape, such that to_shape[bs, indices[bs, sz]] = updates[bs, sz]. That is, I'd like to map the updates into to_shape row by row. Please note that sequence_len and sampled_size are scalar tensors, while the other dimensions are fixed. I tried the following:

new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)

But I got an error:

ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]

Could you please tell me how to use scatter_nd properly? Thanks in advance!

Upvotes: 5

Views: 7912

Answers (2)

Rohan Mukherjee

Reputation: 65

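I think you might be looking for something like this. It uses the same approach of building full 3-D indices for tf.scatter_nd with tf.meshgrid and tf.stack; here the scattered axis is the sequence axis (a per-batch permutation), so the output has the same shape as the input.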

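import tensorflow as tf

def permute_batched_tensor(batched_x, batched_perm_ids):
    # batched_x: [batch, seq, depth] with a fully defined static shape.
    # batched_perm_ids: [batch, seq] int32, one permutation of range(seq) per row.
    # Repeat each target sequence position along the depth axis: [batch, seq, depth]
    indices = tf.tile(tf.expand_dims(batched_perm_ids, 2), [1, 1, batched_x.shape[2]])

    # Create the batch and depth coordinates, tiled along the sequence axis
    i1, i2 = tf.meshgrid(tf.range(batched_x.shape[0]),
                         tf.range(batched_x.shape[2]), indexing="ij")
    i1 = tf.tile(i1[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    i2 = tf.tile(i2[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    # Create final indices: [batch, seq, depth, 3]
    idx = tf.stack([i1, indices, i2], axis=-1)
    # output[b, perm[b, s], d] = batched_x[b, s, d]
    temp = tf.scatter_nd(idx, batched_x, batched_x.shape)
    return temp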
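To check what this index construction does without building a TensorFlow graph, here is a NumPy sketch of the same steps. permute_batched_numpy is a hypothetical name used only for illustration, and the sketch assumes batched_perm_ids holds one permutation of range(seq) per batch row. np.add.at stands in for tf.scatter_nd, which starts from a zero tensor and writes each update at its index.

```python
import numpy as np

def permute_batched_numpy(batched_x, batched_perm_ids):
    """Hypothetical NumPy mirror of permute_batched_tensor:
    out[b, perm[b, s], d] = x[b, s, d]."""
    batch, seq, depth = batched_x.shape
    # Repeat each target sequence position along the depth axis: [batch, seq, depth]
    indices = np.tile(batched_perm_ids[:, :, np.newaxis], (1, 1, depth))
    # Batch and depth coordinates, each [batch, depth], tiled along the sequence axis
    i1, i2 = np.meshgrid(np.arange(batch), np.arange(depth), indexing="ij")
    i1 = np.tile(i1[:, np.newaxis, :], (1, seq, 1))
    i2 = np.tile(i2[:, np.newaxis, :], (1, seq, 1))
    # Full indices: [batch, seq, depth, 3]
    idx = np.stack([i1, indices, i2], axis=-1)
    # Scatter into zeros, adding each element at its (b, perm, d) coordinate
    out = np.zeros_like(batched_x)
    np.add.at(out, (idx[..., 0], idx[..., 1], idx[..., 2]), batched_x)
    return out

x = np.arange(12, dtype=float).reshape(2, 3, 2)
perm = np.array([[2, 0, 1], [1, 2, 0]])
out = permute_batched_numpy(x, perm)
```

With a valid permutation no two updates share an index, so out[b, perm[b, s]] equals x[b, s] for every b and s.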

Upvotes: 0

javidcf

Reputation: 59721

So assuming you have:

  • A tensor updates with shape [batch_size, sequence_len, sampled_size].
  • A tensor indices with shape [batch_size, sequence_len, sampled_size], holding the vocabulary position of each sampled element.

Then you do:

import tensorflow as tf

# Create updates and indices...

# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
                     tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices (indices must have the same dtype as i1 and i2,
# int32 from tf.range here; use tf.cast if needed)
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)

tf.scatter_nd takes an indices tensor, an updates tensor and a shape. updates holds the values you want to scatter, and shape is the desired output shape, here [batch_size, sequence_len, vocab_size]. indices is the more complicated part. Since your output has 3 dimensions (rank 3), each element of updates needs 3 indices that determine where in the output it is placed, so indices must have the same shape as updates plus a trailing dimension of size 3. In this case the first two indices of each element are simply its own batch and sequence positions, but they still have to be spelled out. So we use tf.meshgrid to generate those indices and tile them along the third dimension (every element in the last dimension of updates shares the same first and second index). Finally, we stack them with your mapping indices to get the full 3-dimensional indices.
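The shape requirement can be written down directly. The sketch below is a NumPy reference, not TensorFlow code: expected_updates_shape encodes the rule from the tf.scatter_nd documentation (updates.shape == indices.shape[:-1] + shape[indices.shape[-1]:]), and scatter_to_vocab (a hypothetical helper name) rebuilds the same meshgrid/tile/stack indices and scatters with np.add.at, which sums updates that land on the same position, as current TensorFlow docs describe for tf.scatter_nd.

```python
import numpy as np

def expected_updates_shape(indices_shape, output_shape):
    """Shape rule for scatter_nd:
    updates.shape == indices.shape[:-1] + output_shape[indices.shape[-1]:]"""
    depth = indices_shape[-1]  # how many leading output axes each index addresses
    return tuple(indices_shape[:-1]) + tuple(output_shape[depth:])

def scatter_to_vocab(indices, updates, vocab_size):
    """Hypothetical NumPy mirror of the answer:
    output[b, s, indices[b, s, k]] += updates[b, s, k]."""
    batch_size, sequence_len, sampled_size = updates.shape
    # Batch and sequence coordinates, each [batch, seq], tiled over the sampled axis
    i1, i2 = np.meshgrid(np.arange(batch_size), np.arange(sequence_len), indexing="ij")
    i1 = np.tile(i1[:, :, np.newaxis], (1, 1, sampled_size))
    i2 = np.tile(i2[:, :, np.newaxis], (1, 1, sampled_size))
    # Full indices: [batch, seq, sampled, 3]
    idx = np.stack([i1, i2, indices], axis=-1)
    to_shape = (batch_size, sequence_len, vocab_size)
    if expected_updates_shape(idx.shape, to_shape) != updates.shape:
        raise ValueError("updates shape does not match the scatter_nd shape rule")
    output = np.zeros(to_shape, dtype=updates.dtype)
    np.add.at(output, (idx[..., 0], idx[..., 1], idx[..., 2]), updates)
    return output

# batch 1, sequence 2, sampled 2, vocabulary of 5
indices = np.array([[[3, 0], [1, 4]]])
updates = np.array([[[1.0, 2.0], [3.0, 4.0]]])
out = scatter_to_vocab(indices, updates, vocab_size=5)
```

The same rule accounts for the error in the question (assuming sequence_len is 50): tf.expand_dims(indices, axis=2) has shape [80, 50, 1], so each index addresses only the first output axis and selects a whole [sequence_len, vocab_size] slice. expected_updates_shape((80, 50, 1), (80, 50, vocab_size)) gives (80, 50, 50, vocab_size), which is rank 4, while updates is rank 3.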

Upvotes: 6
