Reputation: 1614
Consider the following code:
#!/usr/bin/env python3
# encoding: utf-8
import numpy as np
import tensorflow as tf  # tf.__version__ == 2.7.0
sample_array = np.random.uniform(size=(2**10, 120, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value=sample_array)
sample_array[:, :, to_select]  # works
sample_tensor[:, :, to_select]  # TypeError. How do I do this with a tensor?
tf.convert_to_tensor(value=sample_tensor.numpy()[:, :, to_select])  # ugly workaround
Basically, how can I get those elements as a tensor of the appropriate shape, just like in NumPy? I tried tf.slice and tf.gather, but cannot figure out the proper arguments to pass.
I could convert it to NumPy and back, but I am not sure whether that sacrifices efficiency, or whether it would still work as part of a custom training loop.
Upvotes: 1
Views: 227
Reputation: 26698
The simplest solution would be to use tf.concat, although it is probably not very efficient, since it creates a separate slice for every selected index:
import numpy as np
import tensorflow as tf
sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value=sample_array)
numpy_way = sample_array[:, :, to_select]
# Slice each selected index from the tensor (not the NumPy array), restore the last axis, then concatenate
tf_way = tf.concat([tf.expand_dims(sample_tensor[:, :, s], axis=-1) for s in to_select], axis=-1)
print(numpy_way)
print(tf_way)
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358 0.21302399 0.86569914]]]
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358 0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
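Regarding the custom-training-loop concern in the question: the concat approach can be wrapped in tf.function as long as it slices the tensor itself and the selected indices are a fixed Python list, because the list comprehension is then unrolled when the function is traced. The .numpy() round trip, by contrast, only works in eager mode. A minimal sketch under those assumptions (select_last_axis and TO_SELECT are hypothetical names):

```python
import numpy as np
import tensorflow as tf

TO_SELECT = [5, 6, 9, 4]  # fixed Python list; the loop below is unrolled at trace time

@tf.function
def select_last_axis(x):
    # x[:, :, s:s + 1] keeps the last axis, so no expand_dims is needed
    return tf.concat([x[:, :, s:s + 1] for s in TO_SELECT], axis=-1)

sample_array = np.random.uniform(size=(2**10, 120, 20))
sample_tensor = tf.convert_to_tensor(value=sample_array)
tf_way = select_last_axis(sample_tensor)
print(tf_way.shape)  # (1024, 120, 4)
print(np.array_equal(tf_way.numpy(), sample_array[:, :, TO_SELECT]))  # True
```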
A more complicated, but efficient, solution would use tf.meshgrid and tf.gather_nd. Check this post or this post, and finally this one. Here is an example based on your question (it reuses sample_tensor from the snippet above):
to_select = tf.constant([5, 6, 9, 4], dtype=tf.int32)
n = tf.shape(to_select)[0]                # number of selected indices
shape = tf.shape(sample_tensor)           # (d0, d1, d2)
# (i, j) index pairs for the two leading axes, shape (d0, d1, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(shape[0], dtype=tf.int32),
    tf.range(shape[1], dtype=tf.int32),
    indexing='ij'), axis=-1)
# Repeat each (i, j) pair n times along axis 1: shape (d0, d1 * n, 2)
ij = tf.repeat(ij, repeats=n, axis=1)
# Tile the selected indices so they line up with the repeated pairs: shape (d0, d1 * n, 1)
k = tf.reshape(tf.tile(to_select, [shape[0] * shape[1]]), (shape[0], shape[1] * n, 1))
gather_indices = tf.concat([ij, k], axis=-1)            # (d0, d1 * n, 3)
result = tf.gather_nd(sample_tensor, gather_indices)    # (d0, d1 * n)
result = tf.reshape(result, (shape[0], shape[1], n))    # (d0, d1, n), no hard-coded sizes
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358 0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
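Since the question mentions tf.gather: it accepts an axis argument (negative values count from the end), so selecting entries along the last axis can also be done in a single call without building an index grid by hand. A sketch using the shapes from the question; its speed relative to the two approaches above is not measured here:

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2**10, 120, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value=sample_array)

# Gather the listed positions along the last axis: (1024, 120, 20) -> (1024, 120, 4)
result = tf.gather(sample_tensor, to_select, axis=-1)

print(result.shape)  # (1024, 120, 4)
print(np.array_equal(result.numpy(), sample_array[:, :, to_select]))  # True
```

The indices can be a Python list, as here, or an integer tensor, so the same call also works when the selection is computed inside a training step.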
Upvotes: 1