DsCpp

tf.gather with indices of higher dimension than the input data?

While reading the code for Dynamic Graph CNN for Learning on Point Clouds, I came across this snippet:

  idx_ = tf.range(batch_size) * num_points
  idx_ = tf.reshape(idx_, [batch_size, 1, 1])

  point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
  point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  # <--- what happens here?
  point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)

Debugging that line, I confirmed the shapes are:

point_cloud_flat: (32768, 3)
nn_idx: (32, 1024, 20)
idx_: (32, 1, 1)
nn_idx + idx_: (32, 1024, 20) after broadcasting
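To make that broadcast concrete, here is a small NumPy sketch with hypothetical sizes (2 clouds of 4 points, 3 neighbors each) in place of 32/1024/20; np.arange and reshape stand in for tf.range and tf.reshape, and the random nn_idx is made-up data:

```python
import numpy as np

batch_size, num_points, k = 2, 4, 3  # hypothetical small sizes

# Made-up per-cloud neighbor indices in [0, num_points), shape (batch_size, num_points, k)
rng = np.random.default_rng(0)
nn_idx = rng.integers(0, num_points, size=(batch_size, num_points, k))

# Same construction as idx_ in the snippet: one offset per cloud (0, num_points, ...)
idx_ = np.arange(batch_size) * num_points
idx_ = idx_.reshape(batch_size, 1, 1)

indices = nn_idx + idx_  # (2, 1, 1) broadcasts against (2, 4, 3)
print(indices.shape)     # (2, 4, 3)
print(idx_.ravel())      # [0 4]
```

Every index of cloud b ends up shifted by b * num_points, while the shape stays that of nn_idx.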

From the tf.gather docs I couldn't work out what the function does when the indices have more dimensions than the input.


Answers (1)

LI Xuhong

The NumPy equivalent of tf.gather is np.take with axis=0 (without an axis argument, np.take works on the flattened array). A simple example:

import numpy as np

params = np.array([4, 3, 5, 7, 6, 8])

# Scalar index: output rank is rank(params) - 1, i.e. 0 here.
indices = 0
print(params[indices])  # 4

# Vector indices: output rank is rank(params), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices])  # [4 3 6]

# Vector indices: output rank is rank(params), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices])  # [5 7 6]

# Higher-rank indices: output rank is rank(params) + rank(indices) - 1, i.e. 2 here.
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices])  # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
#  [5 7 6]]
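The same rule holds when params is 2-D, which is the situation in the question: indexing along axis 0 picks whole rows, and the output shape is the shape of indices followed by the remaining dimensions of params. A small sketch with made-up values:

```python
import numpy as np

# 2-D params: 4 "points" with 3 coordinates each (made-up values)
params = np.arange(12).reshape(4, 3)
# [[ 0  1  2]
#  [ 3  4  5]
#  [ 6  7  8]
#  [ 9 10 11]]

indices = np.array([[0, 3], [2, 2]])  # rank 2

out = np.take(params, indices, axis=0)
print(out.shape)  # (2, 2, 3), i.e. rank 2 + 2 - 1 = 3
print(out[0, 1])  # [ 9 10 11], i.e. row 3 of params

# Fancy indexing on the first axis gives the same result
assert np.array_equal(out, params[indices])
```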

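In your case, the rank of the indices (3) is higher than that of params (2), and that is allowed: the output rank is rank(params) + rank(indices) - 1 = 2 + 3 - 1 = 4, giving shape (32, 1024, 20, 3). The - 1 comes from gathering along a single axis (axis=0 here; the axis argument is a scalar): that axis of params is replaced by the whole shape of indices, so the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:] = () + (32, 1024, 20) + (3,). In other words, each index selects one row (one 3-coordinate point) along the first dimension of point_cloud_flat, in the same way as NumPy "fancy" indexing.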
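For tf.gather with axis=0 (and no batch_dims), the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:]. A quick NumPy check of that rule for the question's layout, using np.take and smaller hypothetical sizes (B=2 clouds, N=5 points, K=4 neighbors, D=3 coordinates) so it runs instantly:

```python
import numpy as np

B, N, K, D = 2, 5, 4, 3  # hypothetical stand-ins for 32, 1024, 20, 3

point_cloud_flat = np.zeros((B * N, D))        # like (32768, 3)
indices = np.zeros((B, N, K), dtype=np.int64)  # like (32, 1024, 20); values only need to be valid rows

axis = 0
out = np.take(point_cloud_flat, indices, axis=axis)
expected = point_cloud_flat.shape[:axis] + indices.shape + point_cloud_flat.shape[axis + 1:]

print(out.shape)  # (2, 5, 4, 3)
assert out.shape == expected
```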

EDIT:

In brief, for your case (if I understand the code correctly):

  • point_cloud is (32, 1024, 3): 32 point clouds in the batch, each with 1024 points of 3 coordinates.
  • nn_idx is (32, 1024, 20): for each of the 1024 points in each of the 32 clouds, the indices of its 20 neighbors. These indices refer to positions within a single cloud of point_cloud, so they lie in [0, 1024).
  • nn_idx+idx_ is (32, 1024, 20): the same neighbor indices, shifted by b * 1024 for cloud b (idx_ holds 0, 1024, 2048, ...), so that they index rows of point_cloud_flat, which stacks all 32 * 1024 = 32768 points.
  • point_cloud_neighbors is finally (32, 1024, 20, 3): the same layout as nn_idx+idx_, except that each entry holds the neighbor's 3 coordinates instead of its index.
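The bullets above can be checked with a small NumPy sketch (hypothetical sizes and random data; np.take stands in for tf.gather on axis 0): gathering from the flattened cloud with offset indices should give the same result as gathering neighbors cloud by cloud with the raw nn_idx.

```python
import numpy as np

B, N, K, D = 2, 5, 3, 3  # hypothetical: clouds, points per cloud, neighbors, coords
rng = np.random.default_rng(42)

point_cloud = rng.normal(size=(B, N, D))     # (B, N, D), made-up coordinates
nn_idx = rng.integers(0, N, size=(B, N, K))  # per-cloud indices in [0, N)

# The trick from the snippet: shift cloud b's indices by b * N
idx_ = (np.arange(B) * N).reshape(B, 1, 1)
point_cloud_flat = point_cloud.reshape(-1, D)                 # (B * N, D)
neighbors = np.take(point_cloud_flat, nn_idx + idx_, axis=0)  # (B, N, K, D)

# Reference: gather each cloud's neighbors separately with the unshifted indices
reference = np.stack([point_cloud[b][nn_idx[b]] for b in range(B)])

print(neighbors.shape)  # (2, 5, 3, 3)
assert np.allclose(neighbors, reference)
```

Without idx_, every cloud would read its neighbors from the first N rows of point_cloud_flat, i.e. from cloud 0; the offsets are what let a single flat gather serve the whole batch.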
