Suppose I have a tensor A of shape (m, n), and I would like to randomly sample k elements (without replacement) from each row, resulting in a tensor B of shape (m, k). How can I do that in TensorFlow?
An example would be:
A: [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
k: 2
B: [[1,3],[5,6],[9,8],[12,10]]
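To pin down the intended semantics: the example B suggests each row is sampled independently (row [1,2,3] keeps positions 0 and 2, while row [4,5,6] keeps positions 1 and 2). A minimal plain-Python reference for that behaviour, using the standard library's random.sample, might look like this. The function name sample_rows_reference is hypothetical, and this is only a way to check results, not a TensorFlow solution.

```python
import random


def sample_rows_reference(a, k, rng=None):
    """Pick k distinct elements from each row of `a`, independently per row.

    Hypothetical plain-Python reference for the intended semantics (not TensorFlow).
    """
    if rng is None:
        rng = random.Random()
    # random.sample draws k items without replacement, in random order
    return [rng.sample(row, k) for row in a]


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
B = sample_rows_reference(A, 2, random.Random(0))
print(B)  # a 4 x 2 nested list; exact values depend on the seed
```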
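One way is to transpose, shuffle, slice, and transpose back. Note that tf.random.shuffle permutes along the first axis, i.e. whole columns of a once it is transposed, so every row of the result uses the same k randomly chosen column positions (as the output below shows) rather than an independent sample per row: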
import tensorflow as tf  # TensorFlow 1.x API (tf.Session); under 2.x use tf.compat.v1

with tf.Graph().as_default(), tf.Session() as sess:
    tf.random.set_random_seed(0)
    a = tf.constant([[1,2,3], [4,5,6], [7,8,9], [10,11,12]], tf.int32)
    k = tf.constant(2, tf.int32)
    # Transpose, shuffle, slice, undo transpose
    aT = tf.transpose(a)
    # tf.random.shuffle permutes along axis 0, i.e. the columns of a
    aT_shuff = tf.random.shuffle(aT)
    aT_shuff_k = aT_shuff[:k]
    result = tf.transpose(aT_shuff_k)
    print(sess.run(result))
    # [[ 3  1]
    #  [ 6  4]
    #  [ 9  7]
    #  [12 10]]
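If each row needs its own independent sample (as in the question's example B), a common alternative is to draw a random key for every element, take the indices of the k largest keys per row with tf.math.top_k, and gather them with tf.gather(..., batch_dims=1). The sketch below assumes TensorFlow 2.x eager execution; the helper names sample_same_columns and sample_per_row are hypothetical, and the first one is just the transpose/shuffle/slice approach above rewritten without a session.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def sample_same_columns(a, k, seed=None):
    """Transpose/shuffle/slice approach: same k column positions for every row."""
    a_t = tf.transpose(a)                          # (n, m)
    shuffled = tf.random.shuffle(a_t, seed=seed)   # permutes axis 0 = columns of a
    return tf.transpose(shuffled[:k])              # (m, k)


def sample_per_row(a, k, seed=None):
    """Sample k distinct elements from each row of `a` independently.

    Every element gets an i.i.d. uniform key; the indices of the k largest keys
    in a row form a uniformly random k-subset (ties have probability ~0).
    """
    keys = tf.random.uniform(tf.shape(a), seed=seed)  # (m, n) random keys
    _, idx = tf.math.top_k(keys, k=k)                 # (m, k) column indices per row
    # batch_dims=1 means result[i, j] = a[i, idx[i, j]]
    return tf.gather(a, idx, batch_dims=1)


A = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
print(sample_same_columns(A, 2).numpy())
print(sample_per_row(A, 2).numpy())
```

tf.argsort(keys, axis=-1)[:, :k] would select the same kind of random subset, but top_k avoids fully sorting each row.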