Reputation: 2489
While reading the code for Dynamic Graph CNN for Learning on Point Clouds, I came across this snippet:
idx_ = tf.range(batch_size) * num_points
idx_ = tf.reshape(idx_, [batch_size, 1, 1])
point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  # <--- what happens here?
point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
While debugging that line, I confirmed that the shapes are:
point_cloud_flat: (32768, 3), nn_idx: (32, 1024, 20), idx_: (32, 1, 1)
so nn_idx+idx_ is (32, 1024, 20) after broadcasting.
From the tf.gather documentation, I couldn't understand what the function does when the indices have more dimensions than the input (params).
Upvotes: 0
Views: 446
Reputation: 2356
An equivalent function in NumPy is np.take. A simple example:
import numpy as np
params = np.array([4, 3, 5, 7, 6, 8])
# Scalar indices; (output is rank(params) - 1), i.e. 0 here.
indices = 0
print(params[indices])  # 4
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices]) # [4 3 6]
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices]) # [5 7 6]
# Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices]) # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
# [5 7 6]]
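In your case, the rank of indices (3) is higher than the rank of params (2), so the output has rank rank(params) + rank(indices) - 1 = 2 + 3 - 1 = 4, i.e. shape (32, 1024, 20, 3). The - 1 is because tf.gather collects slices along a single axis (axis=0 by default; axis must be a scalar, i.e. rank 0): that one axis of params is replaced by the whole shape of indices, so the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:]. In other words, indices picks rows from the first dimension (axis=0) of params in a "fancy" indexing way.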
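The shape rule can be checked with a small sketch. It assumes np.take behaves like tf.gather for axis=0 (as in the example above), and the sizes are made up and much smaller than the ones in the question:

```python
import numpy as np

# Hypothetical small sizes standing in for (32768, 3) and (32, 1024, 20)
params = np.arange(12 * 3).reshape(12, 3)       # rank 2: (12, 3)
rng = np.random.default_rng(0)
indices = rng.integers(0, 12, size=(2, 4, 5))   # rank 3: (2, 4, 5)

out = np.take(params, indices, axis=0)

# Axis 0 of params is replaced by the whole shape of indices:
# params.shape[:0] + indices.shape + params.shape[1:]
print(out.shape)  # (2, 4, 5, 3)

# Each output row is the params row selected by the matching index
assert np.array_equal(out[1, 2, 3], params[indices[1, 2, 3]])
```

The output rank is 4 = 2 + 3 - 1, the same arithmetic as in your case.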
EDITED:
In brief, in your case (if I didn't misunderstand the code): point_cloud is (32, 1024, 3), i.e. 32 batches of 1024 points, each with 3 coordinates. nn_idx is (32, 1024, 20): the indices of the 20 neighbors of each of the 1024 points in each of the 32 batches. These indices are for indexing into point_cloud, so they count from 0 within each batch (0 to 1023). idx_ is (32, 1, 1) and holds b * 1024 for batch b, so nn_idx+idx_ is still (32, 1024, 20) after broadcasting, but its indices are shifted to index into point_cloud_flat (32768 = 32 * 1024 rows). Finally, point_cloud_neighbors is (32, 1024, 20, 3): the same layout as nn_idx+idx_, except that point_cloud_neighbors holds the neighbors' 3 coordinates while nn_idx+idx_ holds only their indices.
Upvotes: 1
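To see why the idx_ offset is needed, here is a minimal NumPy sketch with tiny, made-up sizes and random data. The variable names mirror the snippet in the question, and np.take stands in for tf.gather; it checks that gathering from the flattened array with nn_idx + idx_ gives the same result as indexing each batch separately:

```python
import numpy as np

# Hypothetical tiny sizes instead of 32 / 1024 / 20 / 3
batch_size, num_points, k, num_dims = 2, 4, 2, 3

rng = np.random.default_rng(0)
point_cloud = rng.standard_normal((batch_size, num_points, num_dims))  # (B, N, C)
# Neighbor indices are local to each batch: values in [0, num_points)
nn_idx = rng.integers(0, num_points, size=(batch_size, num_points, k))  # (B, N, K)

# Offset of each batch's first point in the flattened array: b * num_points
idx_ = (np.arange(batch_size) * num_points).reshape(batch_size, 1, 1)  # (B, 1, 1)

point_cloud_flat = point_cloud.reshape(-1, num_dims)  # (B*N, C)
# (B, N, K) + (B, 1, 1) broadcasts to (B, N, K) row indices into the flat array
neighbors = np.take(point_cloud_flat, nn_idx + idx_, axis=0)  # (B, N, K, C)

# Same result computed batch by batch, without the offset trick
expected = np.stack([point_cloud[b][nn_idx[b]] for b in range(batch_size)])
assert neighbors.shape == (batch_size, num_points, k, num_dims)
assert np.array_equal(neighbors, expected)
```

Without idx_, every batch would read its neighbors from the first num_points rows of point_cloud_flat, i.e. from batch 0.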