whiletrue

Reputation: 11075

TensorFlow: slice Tensor and keep original shape

I have a tensor named tensor with shape (?, 1082), and I want to slice it into n subparts in a for loop while keeping the original shape, including the unknown dimension ?.

Example:

lst = []
for n in range(15):
    sub_tensor = tensor[n] # this will reduce the first dimension
    print(sub_tensor.get_shape())

Print output I'm looking for:

(?, 1082) 
(?, 1082)

etc.

How can this be achieved in TensorFlow?

Upvotes: 0

Views: 751

Answers (1)

Sharky

Reputation: 4543

Depending on your constraints, I can think of at least three solutions. The first is tf.split. I'll use tf.placeholder in the examples, but the same calls work on any tensor or variable.

import tensorflow as tf

p = tf.placeholder(shape=[None, 10], dtype=tf.int32)
s1, s2 = tf.split(value=p, num_or_size_splits=2, axis=1)  # each has shape (?, 5)

However, this approach becomes impractical when a large number of splits is needed. Note that tf.split can also split along the None axis.
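When many pieces are needed, you don't have to unpack them into separate names: tf.split returns a Python list, and num_or_size_splits also accepts a list of sizes, which handles widths that don't divide evenly (1082 is not a multiple of 15). Below is a minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 API in graph mode so that placeholders are available; the helper even_sizes is hypothetical, written only for this illustration.

```python
import tensorflow as tf


def even_sizes(total, parts):
    """Hypothetical helper: split `total` into `parts` sizes that differ by at most 1."""
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


graph = tf.Graph()
with graph.as_default():  # placeholders only work in graph mode
    p = tf.compat.v1.placeholder(shape=[None, 1082], dtype=tf.float32)
    # tf.split returns a list of tensors; the unknown first dimension is kept
    pieces = tf.split(p, num_or_size_splits=even_sizes(1082, 15), axis=1)

for piece in pieces:
    print(piece.shape.as_list())  # [None, 73] twice, then [None, 72] thirteen times
```

Using a size list instead of an integer avoids the error tf.split raises when the axis length is not divisible by the number of splits.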

The second option is tf.slice:

s = tf.slice(p, [0, 2], [-1, 2])  # all rows, columns 2 and 3: shape (?, 2)
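tf.slice takes a begin index and a size for every axis, and a size of -1 means "everything from begin to the end of that axis", which is what keeps the unknown first dimension. Here is a sketch under the same assumption (TensorFlow 2.x, tf.compat.v1 graph mode) that walks over p two columns at a time; the window width of 2 is an arbitrary choice for illustration.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():  # placeholders only work in graph mode
    p = tf.compat.v1.placeholder(shape=[None, 10], dtype=tf.int32)
    width = 2  # illustrative window width, not taken from the question
    windows = []
    for start in range(0, 10, width):
        # begin: row 0, column `start`; size: all rows (-1), `width` columns
        windows.append(tf.slice(p, [0, start], [-1, width]))

for w in windows:
    print(w.shape.as_list())  # [None, 2] for each of the 5 windows
```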

tf.slice works on tensors of any rank, but its begin/size arguments are fairly tricky to use. The third option is tf.Tensor.__getitem__, i.e. ordinary indexing, almost as you described in your question. It behaves much like NumPy indexing, so this should do the job:

for n in range(10):
    # n:n+1 (a range, not the integer n) keeps both axes: shape (?, 1)
    print(p[:, n:n + 1].get_shape())
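The key difference from tensor[n] in the question is that an integer index removes its axis, while a range index (start:stop) keeps it. A short sketch, again assuming TensorFlow 2.x with the tf.compat.v1 graph-mode API, compares the static shapes of the two forms.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():  # placeholders only work in graph mode
    p = tf.compat.v1.placeholder(shape=[None, 10], dtype=tf.int32)
    row = p[0, :]      # integer index drops axis 0
    column = p[:, 0]   # integer index drops axis 1, leaving only the unknown axis
    block = p[:, 0:1]  # range index keeps both axes
    # One rank-2 piece per column, each keeping the unknown first dimension
    columns = [p[:, n:n + 1] for n in range(10)]

print(row.shape.as_list())     # [10]
print(column.shape.as_list())  # [None]
print(block.shape.as_list())   # [None, 1]
```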

However, which of these methods fits best depends heavily on your particular application. Hope this helps.

Upvotes: 1
