FesianXu


In TensorFlow, how do I reshape a tensor along a specific axis?

I have a tensor with shape (3*2, 2) that looks like this:

[Image: the (3*2, 2) input tensor]

I want to reshape it to shape (3, 2*2), grouping along a specific axis, like this:

[Image: the desired (3, 2*2) result]

What should I do? The default tf.reshape() reshapes it to:

[Image: the output of the default tf.reshape()]
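The default result follows from tf.reshape reading and writing elements in row-major order, so it never regroups rows along a different axis. A minimal sketch, assuming TensorFlow 2.x eager execution and borrowing the example values 1 to 12 from the answer below (the original images are not available):

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Example input of shape (3*2, 2); values assumed, taken from the answer below
x = tf.constant([[1, 2],
                 [3, 4],
                 [5, 6],
                 [7, 8],
                 [9, 10],
                 [11, 12]])

# tf.reshape walks the elements in row-major order (1, 2, 3, ..., 12)
# and refills the new shape in that same order
default = tf.reshape(x, [3, 2 * 2])
print(default.numpy().tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
```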

SOLUTION: I found that slicing in TensorFlow combined with tf.concat() solves the problem. You can slice out sub-tensors and concatenate them, which solves my problem exactly.
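A minimal sketch of the slice-and-concat approach, assuming TensorFlow 2.x and assuming the desired layout pairs row i with row i + 3 (the layout the answer's code produces): slice the tensor into a top and a bottom half along axis 0, then join the halves side by side along axis 1.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Example input of shape (3*2, 2); values assumed, taken from the answer below
x = tf.constant([[1, 2], [3, 4], [5, 6],
                 [7, 8], [9, 10], [11, 12]])

n = 3  # number of rows in the result

# Slice the (6, 2) tensor into two (3, 2) blocks along axis 0 ...
top, bottom = x[:n], x[n:]
# ... and concatenate them along axis 1, giving shape (3, 4)
result = tf.concat([top, bottom], axis=1)
print(result.numpy().tolist())  # [[1, 2, 7, 8], [3, 4, 9, 10], [5, 6, 11, 12]]
```

The same slicing can be written without explicit indices as tf.concat(tf.split(x, 2, axis=0), axis=1), which splits axis 0 into two equal pieces before concatenating.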


Answers (1)

Nipun Wijerathne

I tried the following code and got the result you need, but I'm not sure whether the number of steps can be reduced.

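import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

x = [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8],
     [9, 10],
     [11, 12]]

# Build the graph on the placeholder so that the fed value is actually used
X = tf.placeholder(tf.float32, shape=[6, 2], name='input')

a = tf.reshape(X, [-1, 6])      # (2, 6): [[1..6], [7..12]]
b = tf.split(a, 3, axis=1)      # list of three (2, 2) tensors
c = tf.reshape(b, [-1, 4])      # list is packed to (3, 2, 2), then reshaped to (3, 4)

with tf.Session() as sess:
    result = sess.run(c, feed_dict={X: x})
    print(result)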

Hope this helps.
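On whether the number of steps can be reduced: one possibly shorter route, sketched here under the same assumptions (TensorFlow 2.x, the row-i-with-row-i+3 layout), is to expose the axis you want to regroup with tf.reshape, move it with tf.transpose, and flatten back. This is a swapped-in technique, not the method used in the answer above.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Example input of shape (3*2, 2); values assumed, same as in the answer above
x = tf.constant([[1, 2], [3, 4], [5, 6],
                 [7, 8], [9, 10], [11, 12]])

blocks = tf.reshape(x, [2, 3, 2])               # (2, 3, 2): two stacked (3, 2) blocks
swapped = tf.transpose(blocks, perm=[1, 0, 2])  # (3, 2, 2): row i pairs x[i] with x[i + 3]
result = tf.reshape(swapped, [3, 2 * 2])        # (3, 4)
print(result.numpy().tolist())  # [[1, 2, 7, 8], [3, 4, 9, 10], [5, 6, 11, 12]]
```

Because the regrouping is expressed as a permutation of axes, the same pattern extends to other shapes by changing the reshape dimensions and the perm argument.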
