TinyEpic

Reputation: 561

Slice multiple slices at once with tensorflow

I am preparing the input tensor for a TensorFlow RNN.
Currently I am doing the following:

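import tensorflow as tf

# one [max_steps, 10] window per batch element, starting at row `each`
rnn_format = list()
for each in range(batch_size):
    rnn_format.append(tf.slice(input2Dpadded, [each, 0], [max_steps, 10]))
lstm_input = tf.stack(rnn_format)  # shape [batch_size, max_steps, 10]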

Is it possible to do this in a single operation, without a loop, using some TensorFlow function?

Upvotes: 6

Views: 4363

Answers (2)

P-Gn

Reputation: 24581

As suggested by Peter Hawkins, you can use tf.gather_nd with appropriate indices to build all the slices at once, without a loop.

Your cropping of the inner dimension (keeping the first 10 columns) is the same for every slice, so it can be done once, before the call to gather_nd.

Example:

import tensorflow as tf
import numpy as np

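# TF 1.x API; under TF 2.x use tf.compat.v1.InteractiveSession() after
# tf.compat.v1.disable_eager_execution() so that .eval() works
sess = tf.InteractiveSession()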

# integer image simply because it is more readable to me
im0 = np.random.randint(10, size=(20,20))
im = tf.constant(im0)

max_steps = 3
batch_size = 10

# indices[i, j, 0] = i + j: the j-th row of the i-th slice
indices = (np.arange(max_steps) +
    np.arange(batch_size)[:,np.newaxis])[...,np.newaxis]
# crop then call gather_nd
res = tf.gather_nd(im[:,:10], indices).eval()

# check that the resulting tensors are equal to what you had previously
for each in range(batch_size):
  assert np.all(tf.slice(im, [each, 0], [max_steps, 10]).eval() == res[each])
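If you want to check the index construction without a TF1 session, it can be reproduced in plain NumPy. The sketch below assumes that `tf.gather_nd` (default `batch_dims=0`) treats the last axis of `indices` as a coordinate into the leading dimensions of `params`, which NumPy fancy indexing can emulate; the helper `gather_nd_np` is a hypothetical stand-in, not a TensorFlow function.

```python
import numpy as np

def gather_nd_np(params, indices):
    """Hypothetical NumPy stand-in for tf.gather_nd (batch_dims=0).

    The last axis of `indices` holds coordinates into the leading
    dimensions of `params`; each coordinate tuple selects one slice.
    """
    indices = np.asarray(indices)
    return params[tuple(np.moveaxis(indices, -1, 0))]

rng = np.random.default_rng(0)
im0 = rng.integers(10, size=(20, 20))

max_steps = 3
batch_size = 10

# row index i + j for window i, step j; trailing axis of size 1 = one coordinate
indices = (np.arange(max_steps) + np.arange(batch_size)[:, np.newaxis])[..., np.newaxis]

# crop the columns once, then gather the rows of every window at once
res = gather_nd_np(im0[:, :10], indices)

# reference result: the original loop of per-window slices
loop = np.stack([im0[each:each + max_steps, :10] for each in range(batch_size)])

print(res.shape)                  # (10, 3, 10)
print(np.array_equal(res, loop))  # True
```

Because each coordinate here has a single component, the gather reduces to `im0[:, :10][indices[..., 0]]`; the general helper also handles multi-component coordinates.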

EDIT

If the start indices of your slices are stored in a tensor, simply replace the NumPy operations with TensorFlow operations when building indices:

# start row of each slice, stored in a 1D tensor
my_indices = tf.constant([1, 8, 3, 0, 0])
indices = (tf.range(max_steps) +
    my_indices[:, tf.newaxis])[..., tf.newaxis]
res = tf.gather_nd(im[:, :10], indices).eval()
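With arbitrary start rows, only the second operand of the sum changes. A NumPy sketch of the same computation, assuming integer row indices only (so the last axis of `indices` has size 1 and the gather reduces to ordinary fancy indexing); the data are deterministic so the selected rows are easy to recognize:

```python
import numpy as np

max_steps = 3
# arbitrary start row for each window, same values as my_indices in the EDIT
my_indices = np.array([1, 8, 3, 0, 0])

# row r, column c holds r * 20 + c
im0 = np.arange(20 * 20).reshape(20, 20)

# shape (len(my_indices), max_steps, 1): window i covers rows my_indices[i] + j
indices = (np.arange(max_steps) + my_indices[:, np.newaxis])[..., np.newaxis]

# NumPy counterpart of tf.gather_nd(im[:, :10], indices) for 1-component coordinates
res = im0[:, :10][indices[..., 0]]

print(res.shape)     # (5, 3, 10)
print(res[1, :, 0])  # column 0 of rows 8, 9, 10: [160 180 200]
```

Every `my_indices[i] + max_steps - 1` must stay below the number of rows; in NumPy an out-of-range start raises an IndexError.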

Further remarks:

  • indices is built using broadcasting in the addition: the operands are virtually tiled so that their shapes match. NumPy and TensorFlow follow the same broadcasting rules.
  • The ellipsis ... is part of standard NumPy slicing notation; it stands for all the dimensions not covered by the other slicing indices. So [..., newaxis] is equivalent to expand_dims(·, -1).
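Both remarks can be checked in NumPy with small shapes; the names `steps`, `starts` and `grid` are chosen only for this illustration:

```python
import numpy as np

max_steps, batch_size = 3, 4

steps = np.arange(max_steps)                   # shape (3,)
starts = np.arange(batch_size)[:, np.newaxis]  # shape (4, 1)

# broadcasting: (3,) and (4, 1) are both stretched to (4, 3)
grid = steps + starts
print(grid)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]]

# [..., np.newaxis] appends a trailing axis, like expand_dims(x, -1)
a = grid[..., np.newaxis]
b = np.expand_dims(grid, -1)
print(a.shape, b.shape)      # (4, 3, 1) (4, 3, 1)
print(np.array_equal(a, b))  # True
```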

Upvotes: 4

Peter Hawkins

Reputation: 3211

Try tf.split or tf.split_v. See here:

https://www.tensorflow.org/api_docs/python/tf/split

Does that help?

Upvotes: 1
