Della

Reputation: 1614

How to do NumPy-like index selection in TensorFlow?

Consider the following code:

#!/usr/bin/env python3
# encoding: utf-8
import numpy as np, tensorflow as tf # tf.__version__==2.7.0
sample_array=np.random.uniform(size=(2**10, 120, 20))
to_select=[5, 6, 9, 4]
sample_tensor=tf.convert_to_tensor(value=sample_array)

sample_array[:, :, to_select] # Works okay
sample_tensor[:, :, to_select] # TypeError. How to do this in tensor? 
tf.convert_to_tensor(value=sample_tensor.numpy()[:, :, to_select]) # Ugly workaround

How can I select those elements as a tensor of the appropriate shape, just as NumPy does? I tried tf.slice and tf.gather, but could not work out which arguments to pass.

I can convert the tensor to NumPy and back, but I am not sure whether that would hurt the operation's efficiency, or whether it would still work as part of a custom training loop.

Upvotes: 1

Views: 227

Answers (1)

AloneTogether

Reputation: 26698

The simplest solution is to slice one index at a time and join the slices with tf.concat, although this is probably not very efficient:

import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value = sample_array)
numpy_way = sample_array[:, :, to_select]
# Slice each selected index from the tensor (not the NumPy array), restore the last axis, then concatenate
tf_way = tf.concat([tf.expand_dims(sample_tensor[:, :, s], axis=-1) for s in to_select], axis=-1)

print(numpy_way)
print(tf_way)
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]]
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)

A more involved but more efficient solution uses tf.meshgrid and tf.gather_nd. Check this post or this post and finally this. Here is an example based on your question:

to_select = tf.constant([5, 6, 9, 4], dtype=tf.int32)
sample_tensor_shape = tf.shape(sample_tensor)

# Build index grids over the first two axes and the selected indices of the last axis
i, j, k = tf.meshgrid(
    tf.range(sample_tensor_shape[0], dtype=tf.int32),
    tf.range(sample_tensor_shape[1], dtype=tf.int32),
    to_select,
    indexing='ij')

# Each entry is a full (i, j, k) coordinate; shape is (dim0, dim1, len(to_select), 3)
gather_indices = tf.stack([i, j, k], axis=-1)

# No hard-coded reshapes, so this works for any first two dimensions
result = tf.gather_nd(sample_tensor, gather_indices, batch_dims=0)
print(result)
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
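
Since the question selects indices along a single axis, tf.gather with its axis argument may already be enough. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; the names gathered, loss and grad are illustrative only and do not come from the question. The GradientTape part is included because the question mentions a custom training loop, where a round trip through .numpy() would not be tracked for gradients.

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(sample_array)

# Pick the listed indices along the last axis, like sample_array[:, :, to_select]
gathered = tf.gather(sample_tensor, to_select, axis=-1)
print(gathered.shape)  # (2, 2, 4)

# tf.gather is differentiable, so it can be used under a GradientTape
with tf.GradientTape() as tape:
    tape.watch(sample_tensor)  # constants must be watched explicitly
    loss = tf.reduce_sum(tf.gather(sample_tensor, to_select, axis=-1))
grad = tape.gradient(loss, sample_tensor)
```

Compared with the tf.concat and tf.gather_nd approaches, this keeps everything in one op and avoids building an explicit coordinate grid, at the cost of only handling index lists along one axis at a time.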

Upvotes: 1
