Abhishek Bhatia

Splitting tensors in TensorFlow

I want to split a tensor into two parts:

ipdb> mean_log_std
<tf.Tensor 'pi/add_5:0' shape=(?, 2) dtype=float32>

Context: ? is the number of samples (unknown when the graph is built), and the other dimension has size 2. I want to split along the second dimension into two tensors, each of size 1 along that dimension.

What I tried (https://www.tensorflow.org/api_docs/python/tf/slice):

ipdb> tf.slice(mean_log_std,[0,2],[0,1])
<tf.Tensor 'pi/Slice_6:0' shape=(0, 1) dtype=float32>
ipdb> tf.slice(mean_log_std,[0,1],[0,1])
<tf.Tensor 'pi/Slice_7:0' shape=(0, 1) dtype=float32>
ipdb>

I expected the two slices above to have shapes (?, 1) and (?, 1).
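For reference: the third argument of tf.slice is a size, not an end index, so a size of 0 along the first axis asks for zero rows, which is why both attempts came back with shape (0, 1). A size of -1 means "take everything from begin to the end of that dimension". Note also that a begin index of 2 on the second axis points past the last column (index 1). Below is a minimal sketch of the corrected calls, assuming TensorFlow 2.x with eager execution (the session above used the 1.x graph API) and a small illustrative constant standing in for mean_log_std:

```python
import tensorflow as tf

# Stand-in for mean_log_std: 3 samples, 2 columns (values are illustrative).
mean_log_std = tf.constant([[0.5, -1.0],
                            [1.5, -2.0],
                            [2.5, -3.0]])

# tf.slice(input_, begin, size): a size of -1 means "all remaining elements".
mean = tf.slice(mean_log_std, [0, 0], [-1, 1])     # column 0, all rows
log_std = tf.slice(mean_log_std, [0, 1], [-1, 1])  # column 1, all rows

print(mean.shape)     # (3, 1)
print(log_std.shape)  # (3, 1)
```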

Answers (1)

akuiper

You can slice the tensor along the second dimension with:

x[:,0:1], x[:,1:2]
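The ranges 0:1 and 1:2 are what keep the second dimension: indexing with a plain integer such as x[:, 0] drops that axis and returns a rank-1 tensor instead of one with shape (?, 1). A small sketch contrasting the two, assuming TensorFlow 2.x with eager execution and a toy 2x2 input:

```python
import tensorflow as tf

x = tf.constant([[1, 2],
                 [3, 4]])

# A range slice keeps the sliced axis, with length 1.
kept = x[:, 0:1]
# An integer index removes that axis entirely.
dropped = x[:, 0]

print(kept.shape)     # (2, 1)
print(dropped.shape)  # (2,)
```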

Or split on the second axis:

y, z = tf.split(x, 2, axis=1)
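In tf.split, an integer second argument splits the axis into that many equal parts (the axis length must be divisible by it), while a list gives explicit sizes that must add up to the axis length. Shape inference keeps the unknown batch dimension. The sketch below assumes TensorFlow 2.x and uses tf.function with an input_signature of shape [None, 2] to mimic the question's shape=(?, 2); the helper name split_columns is hypothetical:

```python
import tensorflow as tf

traced_shapes = []  # static shapes recorded while the function is traced

# [None, 2] mirrors shape=(?, 2): the number of samples is unknown.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)])
def split_columns(x):  # hypothetical helper name
    # Integer argument: 2 equal parts along axis 1.
    y, z = tf.split(x, 2, axis=1)
    # List argument: explicit sizes that sum to the axis length (2).
    a, b = tf.split(x, [1, 1], axis=1)
    traced_shapes.extend([y.shape.as_list(), z.shape.as_list(),
                          a.shape.as_list(), b.shape.as_list()])
    return y, z

y, z = split_columns(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
print(traced_shapes)  # [[None, 1], [None, 1], [None, 1], [None, 1]]
print(y.numpy().tolist(), z.numpy().tolist())  # [[1.0], [3.0]] [[2.0], [4.0]]
```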

Example:

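import tensorflow as tf  # uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session)

x = tf.placeholder(tf.int32, shape=[None, 2])  # None: the number of rows is fed at run time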

y, z = x[:,0:1], x[:,1:2]

y
#<tf.Tensor 'strided_slice_2:0' shape=(?, 1) dtype=int32>

z
#<tf.Tensor 'strided_slice_3:0' shape=(?, 1) dtype=int32>

with tf.Session() as sess:
    print(sess.run(y, {x: [[1,2],[3,4]]}))
    print(sess.run(z, {x: [[1,2],[3,4]]}))
#[[1]
# [3]]
#[[2]
# [4]]

With split:

y, z = tf.split(x, 2, axis=1)

with tf.Session() as sess:
    print(sess.run(y, {x: [[1,2],[3,4]]}))
    print(sess.run(z, {x: [[1,2],[3,4]]}))
#[[1]
# [3]]
#[[2]
# [4]]
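The example above targets TensorFlow 1.x. In TensorFlow 2.x, tf.placeholder and tf.Session are no longer in the top-level namespace (they live under tf.compat.v1) and operations run eagerly by default, so a sketch of the same example, assuming TF 2.x, can work on concrete values directly:

```python
import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])

# Slicing with ranges keeps the second dimension.
y, z = x[:, 0:1], x[:, 1:2]
print(y.numpy())
print(z.numpy())

# Splitting into 2 equal parts along axis 1 gives the same tensors.
y2, z2 = tf.split(x, 2, axis=1)
print(y2.numpy())
print(z2.numpy())
```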
