Andrzej Pronobis

Reputation: 36096

How do I select certain columns of a 2D tensor in TensorFlow?

Since generalized slicing is still being worked on in this issue, what is the best way to implement an op that gathers columns of a 2D tensor (matrix)? For example, for the tensor t:

1 2 3 4
5 6 7 8 

and indices [1,3], I would like to get:

2 4
6 8

This is equivalent to NumPy's t[:, [1, 3]].
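For reference, the target semantics can be checked in plain NumPy. This sketch uses only NumPy, not TensorFlow; `np.take` with `axis=1` is NumPy's counterpart to gathering along the column axis.

```python
import numpy as np

# The example matrix t from the question.
t = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])
cols = [1, 3]

# Advanced (fancy) indexing on the second axis selects whole columns.
selected = t[:, cols]

# np.take along axis=1 gives the same result.
taken = np.take(t, cols, axis=1)

print(selected)
# [[2 4]
#  [6 8]]
```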

Upvotes: 19

Views: 27819

Answers (3)

AlexConfused

Reputation: 831

In the meantime, tf.gather has gained an axis parameter, so columns can be gathered directly:

import tensorflow as tf
params = tf.constant([[1,2,3],[4,5,6]])
indices = [0,2]
op = tf.gather(params, indices, axis=1)

produces the output

[[1 3]
 [4 6]]

Upvotes: 27

Andrzej Pronobis

Reputation: 36096

So far, I have created a workaround that flattens the input and uses gather:

import tensorflow as tf


def gather_cols(params, indices, name=None):
    """Gather columns of a 2D tensor.

    Args:
        params: A 2D tensor.
        indices: A 1D tensor. Must be one of the following types: ``int32``, ``int64``.
        name: A name for the operation (optional).

    Returns:
        A 2D Tensor. Has the same type as ``params``.
    """
    # tf.op_scope was removed in TensorFlow 1.0; tf.name_scope replaces it.
    with tf.name_scope(name or "gather_cols"):
        # Check input
        params = tf.convert_to_tensor(params, name="params")
        indices = tf.convert_to_tensor(indices, name="indices")
        try:
            params.get_shape().assert_has_rank(2)
        except ValueError:
            raise ValueError('\'params\' must be 2D.')
        try:
            indices.get_shape().assert_has_rank(1)
        except ValueError:
            raise ValueError('\'indices\' must be 1D.')

        # Define op
        p_shape = tf.shape(params)
        # Match the dtype of tf.shape (int32) so int64 indices can be added.
        indices = tf.cast(indices, p_shape.dtype)
        p_flat = tf.reshape(params, [-1])
        # The flat index of (row, col) is row * n_cols + col; broadcasting a
        # column of row offsets against the column indices yields all pairs.
        i_flat = tf.reshape(tf.reshape(tf.range(0, p_shape[0]) * p_shape[1],
                                       [-1, 1]) + indices, [-1])
        return tf.reshape(tf.gather(p_flat, i_flat),
                          [p_shape[0], -1])

For example, the following:

params = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
indices = [0, 2]
op = gather_cols(params, indices)

produces the expected output:

[[1 3]
 [4 6]]
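To see why the flattening trick works: in a row-major matrix with `n_cols` columns, element `(r, c)` sits at flat position `r * n_cols + c`. The sketch below reproduces the same index arithmetic in NumPy for the 2×3 example above. NumPy is used here only for illustration; the answer itself builds these indices with TensorFlow ops.

```python
import numpy as np

params = np.array([[1, 2, 3],
                   [4, 5, 6]])
indices = np.array([0, 2])
n_rows, n_cols = params.shape

# Start of each row in the flattened array, as a column vector: [[0], [3]].
row_offsets = (np.arange(n_rows) * n_cols).reshape(-1, 1)

# Broadcasting adds every column index to every row offset: [[0, 2], [3, 5]].
flat_indices = (row_offsets + indices).reshape(-1)  # [0, 2, 3, 5]

# Gather from the flattened array, then restore one output row per input row.
result = params.reshape(-1)[flat_indices].reshape(n_rows, -1)
print(result)
# [[1 3]
#  [4 6]]
```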

Upvotes: 5

lucky6qi

Reputation: 994

There is a function named tf.nn.embedding_lookup(params, ind) which retrieves the rows of the params tensor.

To achieve what you want, first transpose the tensor t whose columns you want to select. Then look up rows of tf.transpose(t), which are the columns of t. Finally, transpose the result back.

import tensorflow as tf


t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])
ind = tf.constant([0, 2])

result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

with tf.Session() as sess:
    print(sess.run(result))
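The transpose trick relies on the identity that selecting rows of the transposed matrix and transposing back is the same as selecting columns of the original. Here is a NumPy sketch of that identity. NumPy stands in for `tf.transpose` and `tf.nn.embedding_lookup`, so this is not the TensorFlow code path itself.

```python
import numpy as np

t = np.array([[1, 2, 3],
              [4, 5, 6]])
ind = np.array([0, 2])

# Rows of t.T are the columns of t; indexing rows with ind mimics a row lookup.
rows_of_transpose = t.T[ind]          # [[1, 4], [3, 6]]

# Transposing back restores the original orientation.
result = rows_of_transpose.T          # [[1, 3], [4, 6]]

print(result)
```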

Upvotes: 9
