For example, if we have:
a = tf.constant(np.eye(5))
a
<tf.Tensor 'Const:0' shape=(5, 5) dtype=float64>
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(5,) dtype=float64>
Slicing tensor a this way reduces its number of dimensions (rank) from 2 to 1.
How can I get the slice directly while keeping the rank unchanged, like this?
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(1, 5) dtype=float64>
(tf.expand_dims(a[0,:], axis=0) works, but is there a more direct and simpler way?)
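The rank reduction is the same rule NumPy uses: an integer index removes the axis it selects. A minimal sketch in NumPy (used here only because it runs without a TensorFlow session) shows the shapes involved and the expand_dims workaround from the question:

```python
import numpy as np

a = np.eye(5)

row = a[0, :]                          # integer index on axis 0 drops that axis
row_2d = np.expand_dims(row, axis=0)   # re-insert a length-1 axis at position 0

print(a.shape)       # (5, 5)
print(row.shape)     # (5,)
print(row_2d.shape)  # (1, 5)
```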
There are at least two direct ways, quite similar to those available in NumPy (related question).
- slice a range of length one: a[x:x+1]
- index with None to insert a new axis: a[None, x]
For example:
a[0:1]
<tf.Tensor 'strided_slice_1:0' shape=(1, 5) dtype=float64>
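Since both forms come from NumPy, they can be checked against expand_dims on a plain NumPy array. A minimal sketch; x = 2 is an arbitrary row index chosen for illustration:

```python
import numpy as np

a = np.eye(5)
x = 2  # hypothetical row index, any valid row works

by_slice = a[x:x+1]                        # length-one range keeps axis 0
by_none = a[None, x]                       # None adds a new leading axis, then x drops the old axis 0
by_expand = np.expand_dims(a[x], axis=0)   # the workaround from the question

for result in (by_slice, by_none, by_expand):
    print(result.shape)  # (1, 5) each time

print(np.array_equal(by_slice, by_none) and np.array_equal(by_none, by_expand))  # True
```

The slice form is the easiest to read when you already think in ranges; the None form is handy when you want to keep writing a single integer index.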
Evaluating the tensors in a session confirms the expected results:
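import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

a = tf.constant(np.eye(5))
with tf.Session() as sess:
    print(repr(sess.run(a[0])))        # integer index: rank 1
    print(repr(sess.run(a[0:1])))      # length-one slice: rank 2
    print(repr(sess.run(a[None, 0])))  # new axis + index: rank 2
array([1., 0., 0., 0., 0.])
array([[1., 0., 0., 0., 0.]])
array([[1., 0., 0., 0., 0.]])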
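The a[x:x+1] idea extends to any axis: put a length-one slice at that axis position and full slices (:) before it, e.g. a[:, x:x+1] for columns. A minimal NumPy sketch; the helper name take_keepdims is hypothetical, not a NumPy or TensorFlow function. The index tuple it builds is exactly what Python passes for a[:, x:x+1], so the same key should also work on a tf.Tensor, but that is an assumption not exercised here:

```python
import numpy as np

def take_keepdims(arr, index, axis):
    """Select position `index` along `axis`, keeping that axis with length 1.

    Hypothetical helper: builds the key (:, ..., :, index:index+1).
    """
    axis = axis % arr.ndim          # allow negative axes such as -1
    if index < 0:                   # normalize so that -1 does not become the empty slice -1:0
        index += arr.shape[axis]
    key = (slice(None),) * axis + (slice(index, index + 1),)
    return arr[key]

a = np.arange(12).reshape(3, 4)
print(take_keepdims(a, 1, axis=0).shape)       # (1, 4)
print(take_keepdims(a, 1, axis=1).shape)       # (3, 1)
print(take_keepdims(a, 1, axis=-1).tolist())   # [[1], [5], [9]]
print(take_keepdims(a, -1, axis=0).tolist())   # [[8, 9, 10, 11]]
```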