YW P Kwon

Reputation: 2168

In Tensorflow, how to use tf.gather() for the last dimension?

I am trying to gather slices of a tensor along the last dimension to build a partial connection between layers. The output tensor has shape [batch_size, h, w, depth], and I want to select slices along the last dimension, for example:

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]

However, tf.gather(L, [0, 2, 3, 8]) seems to work only on the first dimension (right?). Can anyone tell me how to do this?

Upvotes: 18

Views: 44339

Answers (8)

rryan

Reputation: 807

As of TensorFlow 1.3 tf.gather has an axis parameter, so the various workarounds here are no longer necessary.

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
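A minimal sketch of the axis parameter, assuming TensorFlow 2.x with eager execution (in 1.x graph mode the same tf.gather call works, but you evaluate it in a Session). The toy tensor is made up for illustration and has the question's [batch_size, h, w, depth] layout:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Toy tensor of shape [batch_size, h, w, depth] = [2, 2, 2, 10]
L = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
indices = [0, 2, 3, 8]

# Gather along the last dimension; same as L[:, :, :, [0, 2, 3, 8]] in NumPy
partL = tf.gather(L, indices, axis=-1)

print(partL.numpy()[0, 0, 0])                                # [0 2 3 8]
print(np.array_equal(partL.numpy(), L.numpy()[..., indices]))  # True
```

Negative axis values count from the end, so axis=-1 keeps working if the rank of L changes.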

Upvotes: 29

Lerner Zhang

Reputation: 7150

You can try the following approach, which covers the most common case in NLP at least.

Suppose the parameters have shape [batch_size, depth] and the indices are [i, j, k, n, m], one index per batch element (so their length is batch_size). Then tf.gather_nd can help:

parameters = tf.constant([
                          [11, 12, 13], 
                          [21, 22, 23], 
                          [31, 32, 33], 
                          [41, 42, 43]])    
targets = tf.constant([2, 1, 0, 1])    
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])     
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number   
items = tf.gather_nd(parameters, indices)  
# which is what we want: [13, 22, 31, 42]

This snippet first selects the row along the first dimension using batch_nums, then fetches the item in that row using the corresponding target index.
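In newer releases, tf.gather also takes a batch_dims argument that covers this one-index-per-row case without building the coordinate pairs by hand. A minimal sketch, assuming TensorFlow 2.x; the gather_nd version from the answer is kept alongside for comparison:

```python
import tensorflow as tf  # assumes TensorFlow 2.x

parameters = tf.constant([[11, 12, 13],
                          [21, 22, 23],
                          [31, 32, 33],
                          [41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])  # one column index per row

# batch_dims=1 treats dimension 0 as a batch dimension, so output[i]
# is parameters[i, targets[i]]
items = tf.gather(parameters, targets, batch_dims=1)

# The gather_nd version from the answer, for comparison
batch_nums = tf.range(tf.shape(parameters)[0])
items_nd = tf.gather_nd(parameters, tf.stack([batch_nums, targets], axis=1))

print(items.numpy())     # [13 22 31 42]
print(items_nd.numpy())  # [13 22 31 42]
```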

Upvotes: 2

Edward Hughes

Reputation: 171

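A corrected version of @Andrei's answer stacks the row numbers and column indices along axis=1, so that each row of cat_idx is one (row, column) pair instead of everything being concatenated into a single vector:

# matrix: 2-D tensor; indices_for_dim1: one int32 column index per row of matrix
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)  # result[i] = matrix[i, indices_for_dim1[i]]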
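A self-contained run of the stack + gather_nd idea, assuming TensorFlow 2.x; the 3x3 matrix and column indices are illustrative values borrowed from the flattening example further down this thread:

```python
import tensorflow as tf  # assumes TensorFlow 2.x

matrix = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
indices_for_dim1 = tf.constant([1, 0, 2])  # column to take from each row

# Pair every row number with its column index, one pair per row
cat_idx = tf.stack([tf.range(tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)

print(cat_idx.numpy().tolist())  # [[0, 1], [1, 0], [2, 2]]
print(result.numpy())            # [2 4 9]
```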

Upvotes: 3

Adwin Jahn

Reputation: 11

In older TensorFlow versions, Tensor has no shape attribute, only a get_shape() method. The code below is runnable with Python 2.7 and TensorFlow 1.x:

import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print(y.eval())  # [2 4 9]
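A sketch of the same flattened-index trick for TensorFlow 2.x and Python 3 (no Session). It reads the sizes with tf.shape() instead of get_shape(), on the assumption that you may also need it when a dimension such as the batch size is unknown until run time:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, Python 3

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])  # column to pick in each row

# tf.shape() gives the run-time shape, so unknown static dims are fine
num_rows = tf.shape(x)[0]
num_cols = tf.shape(x)[1]

# Row r starts at offset r * num_cols in the flattened tensor
idx_flattened = tf.range(num_rows) * num_cols + idx
y = tf.gather(tf.reshape(x, [-1]), idx_flattened)

print(idx_flattened.numpy())  # [1 3 8]
print(y.numpy())              # [2 4 9]
```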

Upvotes: 1

Yunseong Hwang

Reputation: 51

Yet another solution using tf.unstack(...), tf.gather(...) and tf.stack(...):

Code:

import tensorflow as tf
import numpy as np

shape = [2, 2, 2, 10] 
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)

indices = [0, 2, 3, 8]
axis = -1 # last dimension

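def gather_axis(params, indices, axis=0):
    # Split params into slices along `axis`, pick the requested slices
    # (tf.gather packs the list of slices along a new axis 0), then
    # re-stack the picked slices along `axis`.
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)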

print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
    print(partL)

Result:

L = 
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]

  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]


 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]

  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]

partL = 
[[[[ 0  2  3  8]
   [10 12 13 18]]

  [[20 22 23 28]
   [30 32 33 38]]]


 [[[40 42 43 48]
   [50 52 53 58]]

  [[60 62 63 68]
   [70 72 73 78]]]]
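A TensorFlow 2.x (eager) port of gather_axis, checked against NumPy's np.take on the same toy array. This is a sketch assuming TF 2.x; note that tf.unstack needs the size of `axis` to be known statically, and the approach creates one slice per element of that axis, so on TF 1.3+ a single tf.gather(params, indices, axis=axis) call is likely the cheaper choice:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager)

def gather_axis(params, indices, axis=0):
    # Split into slices along `axis`, pick some, re-stack along `axis`
    picked = tf.gather(tf.unstack(params, axis=axis), indices)
    return tf.stack(tf.unstack(picked), axis=axis)

L = np.arange(2 * 2 * 2 * 10).reshape([2, 2, 2, 10])
indices = [0, 2, 3, 8]

partL = gather_axis(L, indices, axis=-1)
print(np.array_equal(partL.numpy(), np.take(L, indices, axis=-1)))  # True
```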

Upvotes: 4

Andrei Pokrovsky

Reputation: 3846

With gather_nd you can now do this as follows:

# Pair each row number with its column index: cat_idx has shape [num_rows, 2]
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)

Also, as reported by user Nova in a thread referenced in @Yaroslav Bulatov's answer:

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print(y.eval())  # [2 4 9]

The gist is to flatten the tensor and use strided 1-D addressing with tf.gather(...).

Upvotes: 9

Sven Dorkenwald

Reputation: 1

Implementing option 2 from @Yaroslav Bulatov's answer:

# Your indices
indices = [0, 2, 3, 8]

# Remember for the final reshape
n_indices = tf.shape(indices)[0]

flattened_L = tf.reshape(L, [-1])

# Walk strided over the flattened array: offset holds the start of every
# length-depth row, as a column vector so it broadcasts against indices
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
flattened_indices = tf.reshape(tf.reshape(indices, [-1]) + offset, [-1])
selected_rows = tf.gather(flattened_L, flattened_indices)

# Final reshape (TF >= 1.0 argument order: tf.concat(values, axis))
partL = tf.reshape(selected_rows, tf.concat([tf.shape(L)[:-1], [n_indices]], 0))

Credit to the question "How to select rows from a 3-D Tensor in TensorFlow?"
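The same strided approach wrapped as a function for TensorFlow 2.x, and checked against tf.gather(..., axis=-1). The helper name gather_last_dim is hypothetical, and the sketch assumes integer indices that fit in int32:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

def gather_last_dim(L, indices):
    """Hypothetical helper: select slices of L along its last dimension
    using a single 1-D gather on the flattened tensor."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    shape = tf.shape(L)
    depth = shape[-1]
    n_indices = tf.shape(indices)[0]

    flattened_L = tf.reshape(L, [-1])
    # Start offset of every length-`depth` row, as a column vector
    offset = tf.expand_dims(tf.range(0, tf.size(L), depth), 1)
    # Broadcast [1, n] + [rows, 1] -> [rows, n], then flatten
    flattened_indices = tf.reshape(indices[tf.newaxis, :] + offset, [-1])
    selected = tf.gather(flattened_L, flattened_indices)
    # Keep the leading dimensions; the last one becomes len(indices)
    return tf.reshape(selected, tf.concat([shape[:-1], [n_indices]], axis=0))

L = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
partL = gather_last_dim(L, [0, 2, 3, 8])
print(partL.numpy()[1, 1, 1])  # [70 72 73 78]
print(np.array_equal(partL.numpy(), tf.gather(L, [0, 2, 3, 8], axis=-1).numpy()))  # True
```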

Upvotes: 0

Yaroslav Bulatov

Reputation: 57983

There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206

For now you can:

  1. transpose your tensor so that the dimension to gather is first, gather, then transpose back (transpose is expensive)

  2. reshape your tensor into 1-D (reshape is cheap), turn your gather column indices into a list of linear (flattened) element indices, gather, then reshape back

  3. use gather_nd. You will still need to turn your column indices into a list of individual element indices.
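Options 1 and 3 can be sketched as follows, assuming TensorFlow 2.x and, for option 3, a 4-D input like the question's [batch_size, h, w, depth]. Both helper names are hypothetical; each result is compared with tf.gather(..., axis=-1):

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

def gather_last_dim_transpose(L, indices):
    """Option 1 (hypothetical helper): move the last dim to the front,
    gather along axis 0, then move it back."""
    rank = len(L.shape)                    # assumes a statically known rank
    perm = [rank - 1] + list(range(rank - 1))  # e.g. [3, 0, 1, 2]
    inv_perm = list(range(1, rank)) + [0]      # e.g. [1, 2, 3, 0]
    front = tf.transpose(L, perm)
    return tf.transpose(tf.gather(front, indices), inv_perm)

def gather_last_dim_nd(L, indices):
    """Option 3 (hypothetical helper, 4-D input only): build a full
    (b, h, w, k) coordinate for every output element."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    shape = tf.shape(L)
    b, h, w = tf.range(shape[0]), tf.range(shape[1]), tf.range(shape[2])
    # "ij" indexing keeps the grids in (batch, h, w, k) order
    grids = tf.meshgrid(b, h, w, indices, indexing="ij")
    coords = tf.stack(grids, axis=-1)  # [batch, h, w, len(indices), 4]
    return tf.gather_nd(L, coords)

L = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
indices = [0, 2, 3, 8]
expected = tf.gather(L, indices, axis=-1).numpy()

print(np.array_equal(gather_last_dim_transpose(L, indices).numpy(), expected))  # True
print(np.array_equal(gather_last_dim_nd(L, indices).numpy(), expected))         # True
```

Option 3 materializes one 4-element coordinate per output value, so its index tensor is four times larger than the output; option 2 needs only one integer per output value.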

Upvotes: 12
