Masood Delfarah

Reputation: 697

Reverse order of some elements in Tensorflow

Say I have a tensor DATA of shape (M, N, 2) and another tensor IND of shape (N,) containing only zeros and ones.

If IND[i] == 1, then DATA[:, i, 0] and DATA[:, i, 1] should be swapped. If IND[i] == 0, they stay as they are.

How can I do this? I suspect it can be done with tf.gather_nd, but I don't see how.
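One way to see why a gather fits: the last axis has only two entries, so the swap can be written as out[m, n, k] = DATA[m, n, k XOR IND[n]]. The sketch below checks this formulation in NumPy (an assumption for illustration; it is not TensorFlow code), with `np.take_along_axis` standing in for the gather. The helper name `swap_where` is hypothetical.

```python
import numpy as np

def swap_where(data, ind):
    """Hypothetical helper: swap data[:, i, 0] and data[:, i, 1] wherever ind[i] == 1.

    data: array of shape (M, N, 2); ind: 0/1 array of shape (N,).
    """
    ind = np.asarray(ind, dtype=np.int64)
    # Index pattern k XOR ind[n]: [0, 1] keeps the pair, [1, 0] swaps it.
    idx = np.bitwise_xor(np.arange(2)[None, :], ind[:, None])   # shape (N, 2)
    idx = np.broadcast_to(idx[None, :, :], data.shape)            # shape (M, N, 2)
    # Gather along the last axis using the per-column index pattern.
    return np.take_along_axis(data, idx, axis=2)

data = np.array([[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]])  # shape (1, 5, 2)
ind = np.array([1, 0, 1, 0, 1])                               # shape (5,)
print(swap_where(data, ind).tolist())
# [[[2, 1], [2, 3], [4, 3], [4, 5], [6, 5]]]
```

Because the index pattern depends only on n, it is computed once with shape (N, 2) and broadcast over M, so the same code handles any batch size.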

Upvotes: 1

Views: 810

Answers (2)

Amir

Reputation: 16587

Here is one possible solution using tf.equal, tf.where, tf.gather_nd, tf.reverse_v2 and tf.scatter_nd_update (TensorFlow 1.x API):

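import tensorflow as tf  # TF 1.x API (tf.Session, tf.scatter_nd_update)

data = tf.Variable([[[1, 2],
                     [2, 3],
                     [3, 4],
                     [4, 5],
                     [5, 6]]])  # shape=(1, 5, 2)

# reverse the last axis wherever ind is 1
ind = tf.constant([1, 0, 1, 0, 1])  # shape=(5,)

# tile ind to shape (M, N) so every batch row is covered, then
# collect the (m, n) index pairs that need a swap
mask = tf.tile([ind], [tf.shape(data)[0], 1])
cond = tf.where(tf.equal(mask, 1))                    # shape=(k, 2)
match_data = tf.gather_nd(data, cond)                 # shape=(k, 2)
rev_match_data = tf.reverse_v2(match_data, axis=[-1])  # swap each pair
data = tf.scatter_nd_update(data, cond, rev_match_data)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(data))
    #[[[2 1]
    # [2 3]
    # [4 3]
    # [4 5]
    # [6 5]]]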
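The steps of this answer (find the (m, n) pairs to swap, gather them, reverse them, scatter them back) map one-to-one onto NumPy indexing. The sketch below is a NumPy illustration under that assumption, not TensorFlow: `np.nonzero` plays the role of tf.where, fancy indexing plays tf.gather_nd, and fancy assignment plays tf.scatter_nd_update. It uses M = 2 and tiles ind over the batch axis so that every row is covered.

```python
import numpy as np

data = np.array([[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]],
                 [[7, 8], [8, 9], [9, 10], [10, 11], [11, 12]]])  # shape (2, 5, 2)
ind = np.array([1, 0, 1, 0, 1])                                    # shape (5,)

# Like the tf.where(tf.equal(...)) step: (m, n) coordinates of every pair to swap.
mask = np.tile(ind[None, :], (data.shape[0], 1))  # shape (M, N)
rows, cols = np.nonzero(mask == 1)

# Like tf.gather_nd: pull out the selected length-2 vectors, shape (k, 2).
match = data[rows, cols]

# Like tf.reverse_v2(..., axis=[-1]): swap the two entries of each vector.
rev_match = match[:, ::-1]

# Like tf.scatter_nd_update: write the swapped vectors back (into a copy here).
out = data.copy()
out[rows, cols] = rev_match
print(out.tolist())
```

Without the tiling, only coordinates with m = 0 would be selected, which is why the selection mask needs the full (M, N) shape.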

Upvotes: 2

tomkot

Reputation: 956

One way that does not use tf.gather_nd is as follows. The idea is to build DATA1, which is DATA with every pair swapped (i.e. the result you would get if IND were all ones), and then use masks to take each column from either DATA or DATA1, depending on whether it needs a swap.

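import tensorflow as tf

# DATA1: DATA with the last axis swapped everywhere, shape (M, N, 2)
DATA1 = tf.concat([tf.reshape(DATA[:,:,1], [M, N, 1]), tf.reshape(DATA[:,:,0], [M, N, 1])], axis = 2)

# Mask1 is 1 where a swap is needed; its dtype must match DATA (float64 here)
Mask1 = tf.cast(tf.reshape(IND, [1, N, 1]), tf.float64)
Mask0 = 1 - Mask1

# shape (1, N, 1) broadcasts over M and the last axis
Res = tf.multiply(Mask0, DATA) + tf.multiply(Mask1, DATA1)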
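The mask trick relies on broadcasting: Mask1 has shape (1, N, 1), so it stretches over both M and the last axis. Below is a NumPy sketch of the same arithmetic (an assumption for illustration, using float64 data as the cast in this answer implies), next to an equivalent `np.where` selection that avoids the cast and would also work for integer data.

```python
import numpy as np

M, N = 2, 3
DATA = np.arange(M * N * 2, dtype=np.float64).reshape(M, N, 2)
IND = np.array([0, 1, 1])

# DATA1: every pair swapped (same as concatenating DATA[:,:,1] and DATA[:,:,0]).
DATA1 = np.concatenate([DATA[:, :, 1:2], DATA[:, :, 0:1]], axis=2)

# Blend with 0/1 masks broadcast from shape (1, N, 1).
Mask1 = IND.reshape(1, N, 1).astype(np.float64)
Mask0 = 1 - Mask1
Res = Mask0 * DATA + Mask1 * DATA1

# Equivalent selection without multiplying by masks.
Res_where = np.where(IND.reshape(1, N, 1) == 1, DATA1, DATA)
print(Res.tolist())
```

The arithmetic version needs Mask0 and Mask1 in the same dtype as DATA, whereas the selection version only needs a boolean condition.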

Upvotes: 1
