George

Reputation: 944

TensorFlow 2.0 scatter add

I would like to implement the following design in TensorFlow 2.0.

Given a memory tensor of shape [a, b, c],
an indices tensor of shape [a, 1],
and an updates tensor of shape [a, c],

I want to increment memory at the positions indicated by indices with the values in updates.

tf.tensor_scatter_nd_add does not seem to work:

tf.tensor_scatter_nd_add(memory, indices, updates) raises InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [a,b,c] updates: [a,c] [Op:TensorScatterAdd].
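
For reference, a minimal reproduction, with small dimensions chosen just for illustration:

import tensorflow as tf

a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])  # shape [a, 1]
updates = tf.ones([a, c])               # shape [a, c]

# Raises the InvalidArgumentError quoted above
tf.tensor_scatter_nd_add(memory, indices, updates)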

Is it really necessary for updates to have as many inner dimensions as memory? By my reasoning, memory[indices] (as pseudocode) should already be a tensor of shape [a, c]. Indeed, the result of tf.gather_nd(params=memory, indices=indices, batch_dims=1) already has shape [a, c].
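
A quick check of that shape, continuing the reproduction above:

gathered = tf.gather_nd(params=memory, indices=indices, batch_dims=1)
print(gathered.shape)
# (3, 5), i.e. [a, c]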

Could you please recommend an alternative?

Thanks.

Upvotes: 2

Views: 763

Answers (1)

javidcf

Reputation: 59731

tf.tensor_scatter_nd_add indexes only the leading dimensions named by each row of indices, so with indices of shape [a, 1] it expects one full [b, c] slice per row; to add along the second dimension you need two-element indices of the form [batch row, position]. I think what you want is this:

import tensorflow as tf

a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])
updates = 10 * tf.reshape(tf.range(a * c, dtype=memory.dtype), [a, c])
print(updates.numpy())
# [[  0.  10.  20.  30.  40.]
#  [ 50.  60.  70.  80.  90.]
#  [100. 110. 120. 130. 140.]]

# Make indices for first dimension
ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
# Make full indices
indices_2 = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
# Scatter add
out = tf.tensor_scatter_nd_add(memory, indices_2, updates)
print(out.numpy())
# [[[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.  11.  21.  31.  41.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[ 51.  61.  71.  81.  91.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [101. 111. 121. 131. 141.]]]
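
As an aside, the same index construction can be wrapped into a small helper; a minimal sketch reusing the tf import above (the name batch_scatter_add is just illustrative):

def batch_scatter_add(memory, indices, updates):
    # Add updates[i] into memory[i, indices[i, 0], :] for each batch row i.
    # memory: [a, b, c], indices: [a, 1], updates: [a, c].
    ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
    full_indices = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
    return tf.tensor_scatter_nd_add(memory, full_indices, updates)

out = batch_scatter_add(memory, indices, updates)  # same result as `out` above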

Upvotes: 2
