Peter

Reputation: 13

Replacement for "tf.gather_nd"

I am working on a project, but its TensorFlow version does not support tf.gather_nd. Is it possible to use tf.gather, tf.slice, or tf.strided_slice to rewrite the functionality of tf.gather_nd?

tf.gather_nd gathers slices from a tensor into a tensor with a shape specified by indices. Details can be found at https://www.tensorflow.org/api_docs/python/tf/gather_nd
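
For reference, here is a minimal illustration of the behaviour I need to replicate (shown with eager execution just for readability):

import tensorflow as tf

params = tf.constant([[10, 11, 12],
                      [20, 21, 22]])
# Each row of indices is one (row, col) coordinate into params.
indices = tf.constant([[0, 2], [1, 0]])
print(tf.gather_nd(params, indices))  # [12 20]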

Thanks,

Upvotes: 1

Views: 1283

Answers (2)

Wade Wang

Reputation: 690

For other people who would like to implement tf.gather_nd in PyTorch, see https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/37 and the Colab notebook linked there. I adapted it a little bit to implement it with NumPy:

import numpy as np
import tensorflow as tf

def gather_nd(params, indices, batch_dims=0):
    """Use NumPy (and TensorFlow for the final reshape) to implement tf.gather_nd.
    Adapted from: https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/37
    """
    # First convert everything to NumPy, then do the actual work with NumPy operations.
    if isinstance(params, tf.Tensor):
        params = params.numpy()
    elif not isinstance(params, np.ndarray):
        raise ValueError(f'params must be `tf.Tensor` or `numpy.ndarray`. Got {type(params)}')
    if isinstance(indices, tf.Tensor):
        indices = indices.numpy()
    elif not isinstance(indices, np.ndarray):
        raise ValueError(f'indices must be `tf.Tensor` or `numpy.ndarray`. Got {type(indices)}')

    if batch_dims == 0:
        orig_shape = list(indices.shape)
        num_samples = int(np.prod(orig_shape[:-1]))
        m = orig_shape[-1]       # index depth
        n = len(params.shape)    # rank of params

        if m <= n:
            out_shape = orig_shape[:-1] + list(params.shape[m:])
        else:
            raise ValueError(
                f'the last dimension of indices must be less than or equal to the rank of params. '
                f'Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
            )
        # Equivalent to: tf.transpose(tf.reshape(indices, [num_samples, m]), perm=[1, 0])
        indices = indices.reshape((num_samples, m)).transpose().tolist()
        output = params[tuple(indices)]    # (num_samples, ...)

        return tf.reshape(output, out_shape)  # or return the NumPy result: output.reshape(out_shape)
    else:
        batch_shape = params.shape[:batch_dims]
        orig_indices_shape = list(indices.shape)
        orig_params_shape = list(params.shape)
        assert (
            batch_shape == indices.shape[:batch_dims]
        ), 'if batch_dims is not 0, then both "params" and "indices" must have batch_dims leading batch dimensions that exactly match.'
        mbs = int(np.prod(batch_shape))
        if batch_dims != 1:
            # Collapse all batch dimensions into one.
            params = params.reshape(mbs, *(params.shape[batch_dims:]))
            indices = indices.reshape(mbs, *(indices.shape[batch_dims:]))
        # Gather each batch element separately, then stack the results.
        output = []
        for i in range(mbs):
            output.append(gather_nd(params[i], indices[i], batch_dims=0))
        output = np.stack(output, axis=0)
        output_shape = orig_indices_shape[:-1] + list(orig_params_shape[orig_indices_shape[-1] + batch_dims:])
        return tf.reshape(output, output_shape)  # or return the NumPy result: output.reshape(output_shape)
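
A quick sanity check against tf.gather_nd, with hypothetical sample shapes (eager execution is required, since the function above calls .numpy()):

import numpy as np
import tensorflow as tf

params = tf.constant(np.random.rand(4, 5, 6).astype(np.float32))
indices = tf.constant(np.random.randint(5, size=(4, 3, 1)))  # one batch of index vectors
reference = tf.gather_nd(params, indices, batch_dims=1)
candidate = gather_nd(params, indices, batch_dims=1)
print(np.allclose(reference.numpy(), candidate.numpy()))  # True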

Upvotes: 0

javidcf

Reputation: 59701

This function should do equivalent work:

import tensorflow as tf
import numpy as np

def my_gather_nd(params, indices):
    idx_shape = tf.shape(indices)
    params_shape = tf.shape(params)
    idx_dims = idx_shape[-1]
    gather_shape = params_shape[idx_dims:]
    # Collapse the indexed leading axes of params into a single flat axis.
    params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
    # Row-major strides of the indexed axes, used to linearize each multi-index.
    axis_step = tf.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
    indices_flat = tf.reduce_sum(indices * axis_step, axis=-1)
    # Gather along the flat axis, then restore the expected output shape.
    result_flat = tf.gather(params_flat, indices_flat)
    return tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0))

# Test
np.random.seed(0)
with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.constant(np.random.rand(10, 20, 30).astype(np.float32))
    indices = tf.constant(np.stack([np.random.randint(10, size=(5, 8)),
                                    np.random.randint(20, size=(5, 8))], axis=-1))
    result1, result2 = sess.run((tf.gather_nd(params, indices),
                                 my_gather_nd(params, indices)))
    print(np.allclose(result1, result2))
    # True
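
The trick is standard row-major index linearization: tf.cumprod(..., exclusive=True, reverse=True) yields the strides of the indexed axes, so each multi-index is reduced to a single offset into the partially flattened params. A small NumPy-only sketch of that arithmetic, with values chosen purely for illustration:

import numpy as np

# For params of shape (10, 20, 30) and (i, j) indices into the first two axes,
# the strides are the exclusive reverse cumulative product of (10, 20) -> (20, 1).
axis_step = np.array([20, 1])
idx = np.array([3, 7])                   # one (i, j) pair
flat_idx = int((idx * axis_step).sum())  # 3 * 20 + 7 = 67
params = np.random.rand(10, 20, 30)
# Row 67 of params reshaped to (200, 30) is exactly params[3, 7, :].
print(np.array_equal(params.reshape(200, 30)[flat_idx], params[3, 7]))  # True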

Upvotes: 1
