
TensorFlow: How to slice a tensor without changing its number of dimensions?

For example, if we have:

import numpy as np
import tensorflow as tf

a = tf.constant(np.eye(5))
a
<tf.Tensor 'Const:0' shape=(5, 5) dtype=float64>
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(5,) dtype=float64>

Slicing tensor a this way reduces its number of dimensions (its rank) from 2 to 1.

How can I get the slice directly, with the rank unchanged, like this?

a[0,:]
<tf.Tensor 'strided_slice:0' shape=(1, 5) dtype=float64>

(tf.expand_dims(a[0,:], axis=0) works, but is there a more direct and easier way?)
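
A minimal sketch of the question's setup, assuming TensorFlow 2.x with eager execution enabled (the default), where shapes can be inspected directly instead of through the graph-mode tensor repr. It shows the integer index dropping the axis and tf.expand_dims putting it back:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

a = tf.constant(np.eye(5))

row = a[0, :]                               # integer index removes axis 0
row_expanded = tf.expand_dims(row, axis=0)  # re-inserts it as a size-1 axis

print(row.shape.as_list())           # [5]
print(row_expanded.shape.as_list())  # [1, 5]
```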

Answers (1)

E_net4

There are at least two direct ways, quite similar to those available in NumPy (related question).

  1. Take a range of size 1 along that axis: a[x:x+1]
  2. Insert a new axis with None: a[None, x]
a[0:1]
<tf.Tensor 'strided_slice_1:0' shape=(1, 5) dtype=float64>
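
Both patterns should work on any axis, not only the first. A short sketch, assuming TensorFlow 2.x eager execution; x is a hypothetical index chosen for illustration, and tf.newaxis is simply an alias for None:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

a = tf.constant(np.eye(5))
x = 2  # hypothetical index of the row/column to keep

# Keep axis 0 (a row): both results have shape [1, 5].
row_range = a[x:x + 1]
row_newaxis = a[tf.newaxis, x]

# Keep axis 1 (a column): both results have shape [5, 1].
col_range = a[:, x:x + 1]
col_newaxis = a[:, x, tf.newaxis]

for t in (row_range, row_newaxis, col_range, col_newaxis):
    print(t.shape.as_list())
```

The size-1 range keeps the axis because a slice never removes a dimension, while None adds one back after the integer index has removed it; which one reads better is largely a matter of taste.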

Evaluating the tensors in a session shows the expected results:

with tf.Session() as sess:
    sess.run(a[0])
    sess.run(a[0:1])
    sess.run(a[None, 0])
array([1., 0., 0., 0., 0.])
array([[1., 0., 0., 0., 0.]])
array([[1., 0., 0., 0., 0.]])
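
tf.Session belongs to the TensorFlow 1.x graph API; in TensorFlow 2.x it is only reachable as tf.compat.v1.Session. A sketch of the same check, assuming TensorFlow 2.x eager execution, where .numpy() returns a tensor's value directly:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

a = tf.constant(np.eye(5))

# Eager tensors are evaluated immediately; .numpy() converts to a NumPy array.
print(a[0].numpy())        # [1. 0. 0. 0. 0.]
print(a[0:1].numpy())      # [[1. 0. 0. 0. 0.]]
print(a[None, 0].numpy())  # [[1. 0. 0. 0. 0.]]
```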
