shaunshd

Reputation: 95

What is the difference between tf.scatter_add and tf.scatter_nd when indices is a matrix?

Both tf.scatter_add and tf.scatter_nd allow indices to be a matrix. It is clear from the documentation of tf.scatter_nd that the last dimension of indices contains values that are used to index a tensor of shape shape. The other dimensions of indices define the number of elements/slices to be scattered. Suppose updates has rank N. The first k dimensions of indices (all except the last dimension) should match the first k dimensions of updates, and the last (N-k) dimensions of updates should match the last (N-k) dimensions of shape.

This implies that tf.scatter_nd can be used to perform an N-dimensional scatter. However, tf.scatter_add also takes matrices as indices, and it's not clear which dimensions of indices correspond to the number of scatters to be performed, or how these dimensions align with updates. Can someone provide a clear explanation, possibly with examples?
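For concreteness, here is a minimal sketch of the tf.scatter_nd shape rule described above (the shapes are chosen purely for illustration):

```python
import tensorflow as tf

# indices has shape (2, 1): two scatters, each row indexing axis 0 of
# a result of shape (4, 3); here k = 1 and N = 2.
indices = tf.constant([[1], [3]])
# updates has shape (2, 3): its first k dims match indices' first k dims,
# and its last (N - k) dims match the last (N - k) dims of shape.
updates = tf.constant([[1., 1., 1.],
                       [2., 2., 2.]])
result = tf.scatter_nd(indices, updates, shape=[4, 3])
# rows 1 and 3 now hold the two update slices; all other rows are zero
```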

Upvotes: 2

Views: 1103

Answers (1)

Clock ZHONG

Reputation: 980

@shaunshd, I finally fully understand the relationship between the three tensors in the tf.scatter_nd_*() arguments, especially when indices has multiple dimensions, e.g.: indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]], dtype=tf.int32)

In cases like this, tf.rank(indices) == 2: each row of indices addresses one element (or slice) of ref, and the first dimension counts how many scatters are performed. (tf.scatter_nd itself does accept higher-rank indices, where the extra leading dimensions simply batch more scatters, but rank 2 covers the common case.)

The following is my test code, showing a more complex case than the examples provided on TensorFlow's official website:

import numpy as np
import tensorflow as tf

def testScatterNDUpdate(self):
    ref = tf.Variable(np.zeros(shape=[4, 4, 4], dtype=np.float32))
    indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]], dtype=tf.int32)
    updates = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32)
    # ref.shape == (4, 4, 4)
    print(tf.tensor_scatter_nd_update(ref, indices, updates))
    print(ref.scatter_nd_update(indices, updates))
    # conditions are:
    #     updates.shape[0] == indices.shape[0]  (one update per row of indices)
    #     indices.shape[1] <= len(ref.shape)    (each row indexes an element or a slice)
    #     tf.rank(indices) == 2
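To see the indices.shape[1] <= len(ref.shape) condition in action, here is a sketch of my own (not from the official docs) where each row of indices addresses a whole (4, 4) slice rather than a single element:

```python
import tensorflow as tf

ref = tf.zeros([4, 4, 4])
# indices.shape[-1] == 1 < 3 == len(ref.shape), so each row picks a slice
indices = tf.constant([[0], [2]])
# one (4, 4) slice of updates per row of indices
updates = tf.ones([2, 4, 4])
result = tf.tensor_scatter_nd_update(ref, indices, updates)
# result[0] and result[2] are all ones; result[1] and result[3] stay zero
```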

You could also understand indices through the following pseudocode:

def scatter_nd_update(ref, indices, updates):
    # each row of indices addresses one element (or slice) of ref
    for i in range(indices.shape[0]):
        ref[tuple(indices[i])] = updates[i]
    return ref
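Running a plain numpy version of that loop on the same data reproduces what the TensorFlow calls print (numpy is used here only to illustrate the semantics):

```python
import numpy as np

def scatter_nd_update(ref, indices, updates):
    # each row of indices addresses one element of ref
    for i in range(indices.shape[0]):
        ref[tuple(indices[i])] = updates[i]
    return ref

ref = np.zeros((4, 4, 4), dtype=np.float32)
indices = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3, 3, 2]])
updates = np.array([1, 2, 3, 4, 5], dtype=np.float32)
out = scatter_nd_update(ref, indices, updates)
print(out[3, 3, 2])  # 5.0
```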

Compared with numpy's fancy indexing, TensorFlow's indexing features are still quite difficult to use, and their style is not yet unified with numpy's. Hope the situation gets better in tf3.x.

Upvotes: 1
