M. Merida-Floriano

Reputation: 357

How to use `tf.scatter_nd` with multi-dimensional tensors

I'm trying to create a new tensor (output) with the values of another tensor (updates) placed according to an idx tensor. The shape of output should be [batch_size, 1, 4, 4] (like a 4x4-pixel image with one channel), and updates has shape [batch_size, 3].

I've read the TensorFlow documentation (I'm working with the GPU version, 1.13.1) and found that tf.scatter_nd should work for my problem. The issue is that I cannot make it work; I think I'm having trouble understanding how I have to arrange idx.

Let's consider batch_size = 2, so what I'm doing is:

import tensorflow as tf

updates = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
output_shape = tf.constant([2, 1, 4, 4])
idx = tf.constant([[[1, 0], [1, 1], [1, 0]], [[0, 0], [0, 1], [0, 2]]])  # shape [2, 3, 2]
idx_expanded = tf.expand_dims(idx, 1)  # so I have shape [2, 1, 3, 2]
output = tf.scatter_nd(idx_expanded, updates, output_shape)

I expected this to work, but it fails with this error:

ValueError: The outer 3 dimensions of indices.shape=[2,1,3,2] must match the outer 3 dimensions of updates.shape=[2,3]: Shapes must be equal rank, but are 3 and 2 for 'ScatterNd_7' (op: 'ScatterNd') with input shapes: [2,1,3,2], [2,3], [4]

I don't understand why it's expecting updates to have dimension 3. I thought idx has to make sense with output_shape (that's why I used expand_dims) and also with updates (specify the two indices for the three points), but it's obvious I'm missing something here.

Any help would be appreciated.

Upvotes: 3

Views: 2679

Answers (1)

M. Merida-Floriano

Reputation: 357

I've been playing around with the function and I have found my mistake. If anyone is facing this problem, this is what I did to solve it:

With batch_size = 2 and 3 points per batch, the idx tensor must have shape [2, 3, 4]. The first dimension corresponds to the batch the update value is taken from, the second must equal the second dimension of updates (the number of points per batch), and the third is 4 because each point needs 4 indices: [batch_number, channel, row, col]. Following the example in the question:

import tensorflow as tf

updates = tf.constant([[1., 2., 3.], [4., 5., 6.]])  # [2, 3]
# Each index is [batch_number, channel, row, col]
idx = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 0], [0, 0, 1, 0]], [[1, 0, 1, 1], [1, 0, 0, 0], [1, 0, 1, 0]]])  # [2, 3, 4]
output = tf.scatter_nd(idx, updates, [2, 1, 4, 4])

sess = tf.Session()  # TF 1.x graph mode
print(sess.run(output))

[[[[2. 1. 0. 0.]
   [3. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[5. 0. 0. 0.]
   [6. 4. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]]
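
Writing the four-element indices by hand gets tedious for larger batches. As a minimal sketch, they can be built from per-batch (row, col) pairs with plain Python lists, which can then be passed to tf.constant. The helper names build_indices and scatter_reference are hypothetical, and fixing the single channel at 0 is an assumption. The reference scatter only mimics the placement for this 4-D case so the result can be checked without a session; it adds together updates that land on the same cell, which is how tf.scatter_nd is documented to treat duplicate indices.

```python
def build_indices(row_col, channel=0):
    """Turn per-batch (row, col) pairs into [batch, channel, row, col] indices.

    row_col has shape [batch_size, n_points, 2]; the result has shape
    [batch_size, n_points, 4].
    """
    return [
        [[b, channel, r, c] for r, c in points]
        for b, points in enumerate(row_col)
    ]


def scatter_reference(indices, updates, shape):
    """Pure-Python stand-in for the 4-D scatter, for checking only."""
    n_batch, n_chan, n_rows, n_cols = shape
    out = [[[[0.0] * n_cols for _ in range(n_rows)]
            for _ in range(n_chan)]
           for _ in range(n_batch)]
    for idx_row, upd_row in zip(indices, updates):
        for (b, ch, r, c), value in zip(idx_row, upd_row):
            out[b][ch][r][c] += value  # duplicate indices accumulate
    return out


# (row, col) pairs matching the example above
row_col = [[(0, 1), (0, 0), (1, 0)],
           [(1, 1), (0, 0), (1, 0)]]
updates = [[1., 2., 3.], [4., 5., 6.]]

idx = build_indices(row_col)
output = scatter_reference(idx, updates, [2, 1, 4, 4])
print(idx[1])        # [[1, 0, 1, 1], [1, 0, 0, 0], [1, 0, 1, 0]]
print(output[0][0])  # first batch, single channel
```

The nested list returned by build_indices has the same [2, 3, 4] layout as the idx tensor above, so it should be usable as tf.constant(idx) in the real call.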

This way it's possible to place specific numbers in a new tensor.
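
As for the error in the question: as far as I understand the documented shape rule, updates must have shape indices.shape[:-1] + output_shape[indices.shape[-1]:]. With two-component indices, each index selects a whole [4, 4] slice instead of a single pixel, so updates would have needed many more dimensions than [2, 3]. A minimal sketch of that rule in plain Python (the helper name expected_updates_shape is hypothetical, not a TensorFlow function):

```python
def expected_updates_shape(indices_shape, output_shape):
    """Shape `updates` must have for the given indices and output shapes."""
    index_depth = indices_shape[-1]  # how many output dims each index addresses
    if index_depth > len(output_shape):
        raise ValueError("index depth exceeds output rank")
    # Leading dims come from indices, trailing dims from the un-indexed
    # part of the output.
    return list(indices_shape[:-1]) + list(output_shape[index_depth:])


# Attempt from the question: each index has only 2 components
print(expected_updates_shape([2, 1, 3, 2], [2, 1, 4, 4]))  # [2, 1, 3, 4, 4]
# Working version: each index has 4 components
print(expected_updates_shape([2, 3, 4], [2, 1, 4, 4]))     # [2, 3]
```

Only the second case matches the [2, 3] shape of updates, which is why the four-component indices work.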

Upvotes: 6
