Reputation: 13
I am working on a project, but the TensorFlow version it uses does not support tf.gather_nd. Is it possible to rewrite tf.gather_nd as a function using tf.gather, tf.slice, or tf.strided_slice?
tf.gather_nd gathers slices from a tensor into a tensor with a shape specified by indices. Details can be found at https://www.tensorflow.org/api_docs/python/tf/gather_nd
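For example, a toy case of the behavior I need (each row of indices is one index tuple into params):
import tensorflow as tf

params = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([[2], [0]])
# tf.gather_nd(params, indices) -> [[5, 6], [1, 2]]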
Thanks,
Upvotes: 1
Views: 1283
Reputation: 690
For other people who would like to implement tf.gather_nd in PyTorch, see https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/37 and the linked Colab notebook. I adapted it a bit to implement it with NumPy:
import numpy as np
import tensorflow as tf

def gather_nd(params, indices, batch_dims=0):
    """Use NumPy and TensorFlow to implement tf.gather_nd (requires eager execution).
    Adapted from: https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/37
    """
    # First convert the inputs to NumPy arrays, then do all the work with NumPy.
    if isinstance(params, tf.Tensor):
        params = params.numpy()
    elif not isinstance(params, np.ndarray):
        raise ValueError(f'params must be `tf.Tensor` or `numpy.ndarray`. Got {type(params)}')
    if isinstance(indices, tf.Tensor):
        indices = indices.numpy()
    elif not isinstance(indices, np.ndarray):
        raise ValueError(f'indices must be `tf.Tensor` or `numpy.ndarray`. Got {type(indices)}')
    if batch_dims == 0:
        orig_shape = list(indices.shape)
        num_samples = int(np.prod(orig_shape[:-1]))
        m = orig_shape[-1]     # length of each index tuple
        n = len(params.shape)  # rank of params
        if m <= n:
            out_shape = orig_shape[:-1] + list(params.shape[m:])
        else:
            raise ValueError(
                f'The last dimension of indices must be less than or equal to the rank of params. '
                f'Got indices: {indices.shape}, params: {params.shape}. {m} > {n}'
            )
        # Equivalent TF op: tf.transpose(tf.reshape(indices, [num_samples, m]), perm=[1, 0])
        indices = indices.reshape((num_samples, m)).transpose().tolist()
        # Advanced indexing with a tuple of per-dimension index lists.
        output = params[tuple(indices)]  # (num_samples, ...)
        return tf.reshape(output, out_shape)  # or return NumPy type: output.reshape(out_shape)
    else:
        batch_shape = params.shape[:batch_dims]
        orig_indices_shape = list(indices.shape)
        orig_params_shape = list(params.shape)
        assert batch_shape == indices.shape[:batch_dims], (
            'If batch_dims is not 0, both params and indices must have batch_dims '
            'leading batch dimensions that match exactly.'
        )
        # Merge all batch dimensions into one, gather per batch element, then restore them.
        mbs = int(np.prod(batch_shape))
        if batch_dims != 1:
            params = params.reshape(mbs, *(params.shape[batch_dims:]))
            indices = indices.reshape(mbs, *(indices.shape[batch_dims:]))
        output = []
        for i in range(mbs):
            output.append(gather_nd(params[i], indices[i], batch_dims=0))
        output = np.stack(output, axis=0)
        output_shape = orig_indices_shape[:-1] + list(orig_params_shape[orig_indices_shape[-1] + batch_dims:])
        return tf.reshape(output, output_shape)  # or return NumPy type: output.reshape(output_shape)
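A minimal sanity check against the built-in op, assuming eager execution (the function above calls .numpy()); the shapes and index ranges here are arbitrary:
params = tf.constant(np.random.rand(4, 5, 6).astype(np.float32))
indices = tf.constant(np.random.randint(0, 4, size=(3, 2)))  # index tuples into the first two dims
print(np.allclose(gather_nd(params, indices).numpy(),
                  tf.gather_nd(params, indices).numpy()))
# True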
Upvotes: 0
Reputation: 59701
This function should do equivalent work. The idea is to flatten the indexed dimensions of params into a single leading dimension, convert each multi-dimensional index into a flat index using row-major strides (an exclusive, reversed cumulative product of the indexed dimension sizes), and then do a single tf.gather:
import tensorflow as tf
import numpy as np

def my_gather_nd(params, indices):
    idx_shape = tf.shape(indices)
    params_shape = tf.shape(params)
    idx_dims = idx_shape[-1]                # number of dimensions each index tuple addresses
    gather_shape = params_shape[idx_dims:]  # shape of each gathered slice
    # Flatten the indexed dimensions of params into one leading dimension.
    params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
    # Row-major strides for the indexed dimensions, e.g. shape (10, 20) -> [20, 1].
    axis_step = tf.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
    # Convert each index tuple to a flat index: sum(index * stride).
    # Cast so the multiply works whether indices are int32 or int64.
    indices_flat = tf.reduce_sum(indices * tf.cast(axis_step, indices.dtype), axis=-1)
    result_flat = tf.gather(params_flat, indices_flat)
    # Restore the leading index shape in front of each gathered slice.
    return tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0))
# Test
np.random.seed(0)
with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.constant(np.random.rand(10, 20, 30).astype(np.float32))
    indices = tf.constant(np.stack([np.random.randint(10, size=(5, 8)),
                                    np.random.randint(20, size=(5, 8))], axis=-1))
    result1, result2 = sess.run((tf.gather_nd(params, indices),
                                 my_gather_nd(params, indices)))
    print(np.allclose(result1, result2))
    # True
Upvotes: 1