Reputation: 95
Both tf.scatter_add and tf.scatter_nd allow indices to be a matrix. It is clear from the documentation of tf.scatter_nd that the last dimension of indices contains values that are used to index a tensor of shape shape. The other dimensions of indices define the number of elements/slices to be scattered. Suppose updates has rank N. The first k dimensions of indices (all except the last dimension) should match the first k dimensions of updates. The last (N-k) dimensions of updates should match the last (N-k) dimensions of shape.
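For concreteness, here is a small example of those shape constraints (my own, not from the docs): scattering whole rows into a 4x4 tensor, so k = 1 and N = 2.

```python
import tensorflow as tf

# indices has shape [2, 1]: two scatters; the last dimension (size 1)
# indexes only the first dimension of shape, so each update is a whole row.
indices = tf.constant([[1], [3]])

# updates has shape [2, 4]: its first dimension matches the first
# dimension of indices; its trailing dimension matches shape's last dim.
updates = tf.constant([[1., 1., 1., 1.],
                       [2., 2., 2., 2.]])

result = tf.scatter_nd(indices, updates, shape=[4, 4])
print(result.numpy())  # rows 1 and 3 filled, everything else zero
```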
This implies that tf.scatter_nd can be used to perform an N-dimensional scatter. However, tf.scatter_add also takes matrices as indices, and it is not clear which dimensions of indices correspond to the number of scatters to be performed, or how those dimensions align with updates. Can someone provide a clear explanation, possibly with examples?
Upvotes: 2
Views: 1103
Reputation: 980
@shaunshd, I finally fully understand the relationship between the three tensors in the tf.scatter_nd_*() arguments, especially when indices has multiple dimensions. e.g.: indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]], dtype=tf.int32)
Don't expect tf.rank(indices) > 2; here tf.rank(indices) == 2 always holds.
The following is my test code, showing a more complex case than the examples provided on TensorFlow's official website:
def testScatterNDUpdate(self):
    ref = tf.Variable(np.zeros(shape=[4, 4, 4], dtype=np.float32))
    indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]], dtype=tf.int32)
    updates = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32)
    # shape = (4, 4, 4)
    print(tf.tensor_scatter_nd_update(ref, indices, updates))
    print(ref.scatter_nd_update(indices, updates))
    # print(updates.shape[-1] == shape[-1], updates.shape[0] <= shape[0])
    # The conditions are:
    #   updates.shape[0] == indices.shape[0]
    #   indices.shape[1] <= len(shape)
    #   tf.rank(indices) == 2
You can also understand indices via the following pseudocode:
def scatter_nd_update(ref, indices, updates):
    for i in range(tf.shape(indices)[0]):
        ref[indices[i]] = updates[i]
    return ref
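Making the pseudocode concrete (a sketch with NumPy, assuming eager TF 2.x), a plain Python loop reproduces tf.tensor_scatter_nd_update:

```python
import numpy as np
import tensorflow as tf

ref = np.zeros([4, 4, 4], dtype=np.float32)
indices = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]])
updates = np.array([1, 2, 3, 4, 5], dtype=np.float32)

# ref[indices[i]] means indexing with one coordinate per axis,
# i.e. ref[tuple(indices[i])] in NumPy terms.
for i in range(indices.shape[0]):
    ref[tuple(indices[i])] = updates[i]

# The loop matches TensorFlow's result:
tf_result = tf.tensor_scatter_nd_update(tf.zeros([4, 4, 4]), indices, updates)
print(np.allclose(ref, tf_result.numpy()))  # True
```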
Compared with NumPy's fancy indexing, TensorFlow's indexing features are still difficult to use and follow a different style; they are not yet unified with NumPy's. I hope the situation improves in TF 3.x.
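To illustrate the style difference (my own sketch): NumPy scatters in place through np.add.at, while TensorFlow needs a dedicated function that returns a new tensor. Both accumulate duplicate indices.

```python
import numpy as np
import tensorflow as tf

indices = np.array([[0, 0], [1, 1], [1, 1]])  # note the duplicate index
updates = np.array([1., 2., 3.], dtype=np.float32)

# NumPy: np.add.at does an unbuffered (duplicate-safe) scatter-add in place.
a = np.zeros([3, 3], dtype=np.float32)
np.add.at(a, (indices[:, 0], indices[:, 1]), updates)

# TensorFlow: tf.tensor_scatter_nd_add returns a new tensor.
b = tf.tensor_scatter_nd_add(tf.zeros([3, 3]), indices, updates)
print(a[1, 1], b.numpy()[1, 1])  # duplicates accumulate: 2 + 3 = 5 in both
```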
Upvotes: 1