Steve Yang

Reputation: 225

TensorFlow: randomly sample from each row

Suppose I have a tensor A of shape (m, n). I would like to randomly sample k elements (without replacement) from each row, resulting in a tensor B of shape (m, k). How can I do that in TensorFlow?

An example would be:

A: [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]

k: 2

B: [[1,3],[5,6],[9,8],[12,10]]
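In the example B, each row is sampled independently: row 1 keeps columns 0 and 2, row 2 keeps columns 1 and 2, row 3 keeps columns 2 and 1, and row 4 keeps columns 2 and 0. To pin down that intended behaviour, here is a minimal plain-Python reference. It is not TensorFlow. It assumes the rows are ordinary lists and uses the standard library's `random.sample`. The helper name `sample_rows` is hypothetical.

```python
import random


def sample_rows(a, k, rng=random):
    """Draw k distinct positions from each row of `a`, independently per row.

    `a` is a list of equal-length lists; `rng` is anything with a `sample`
    method (the `random` module itself or a seeded `random.Random`).
    """
    # random.sample picks k distinct positions, i.e. sampling without replacement
    return [rng.sample(row, k) for row in a]


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
B = sample_rows(A, 2, rng=random.Random(0))
print(B)  # a 4 x 2 list; which elements appear depends on the seed
```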

Upvotes: 1

Views: 524

Answers (1)

javidcf

Reputation: 59731

This is a way to do that:

import tensorflow as tf

# Graph-mode (TF 1.x style) code; the tf.compat.v1 calls also run on TF 2.x
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    tf.compat.v1.set_random_seed(0)
    a = tf.constant([[1,2,3], [4,5,6], [7,8,9], [10,11,12]], tf.int32)
    k = tf.constant(2, tf.int32)
    # Transpose so columns become rows, shuffle them (tf.random.shuffle
    # permutes along the first dimension), keep the first k, undo the transpose
    aT = tf.transpose(a)
    aT_shuff = tf.random.shuffle(aT)
    aT_shuff_k = aT_shuff[:k]
    result = tf.transpose(aT_shuff_k)
    print(sess.run(result))
    # [[ 3  1]
    #  [ 6  4]
    #  [ 9  7]
    #  [12 10]]
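`tf.random.shuffle` permutes the transposed tensor along its first dimension. As a result, the transpose-and-shuffle approach draws a single permutation of the columns and applies it to every row, so all rows keep the same k columns (columns 2 and 0 in the printed output). The question's B instead samples each row independently. Here is one possible sketch for that case, assuming TensorFlow 2.x with eager execution:

1. Give every element its own uniform random key.
2. Take the indices of the k largest keys in each row with `tf.math.top_k`.
3. Pick those elements with `tf.gather(..., batch_dims=1)`.

The helper name `sample_per_row` is illustrative, not part of TensorFlow.

```python
import tensorflow as tf


def sample_per_row(a, k, seed=None):
    """Sample k distinct elements from each row of a 2-D tensor, independently per row."""
    # One independent uniform key per element, so every row gets its own randomness
    keys = tf.random.uniform(tf.shape(a), seed=seed)
    # Positions of the k largest keys in each row: a random k-subset of the
    # column indices, always distinct (i.e. without replacement)
    _, idx = tf.math.top_k(keys, k=k)
    # batch_dims=1 pairs row i of `a` with row i of `idx`: out[i, j] = a[i, idx[i, j]]
    return tf.gather(a, idx, batch_dims=1)


a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], tf.int32)
b = sample_per_row(a, 2)
print(b.numpy())  # shape (4, 2); columns can differ from row to row
```

Because each row's keys are drawn independently, two rows can end up with different columns, as in the question's example.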

Upvotes: 1
