I'm trying to extract all possible permutations of a tensor along a specific axis. My input is a [B, S, L] tensor (B batches of S vectors of length L), and I want every possible ordering of these vectors (the S! permutations), i.e. a [B, S!, S, L] tensor as output.
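Concretely, with perms listing the permutations of range(S), the desired output satisfies out[b, p, s, :] = x[b, perms[p][s], :]. A minimal NumPy sketch of that target, meant only as a reference to check a TensorFlow version against (it is plain NumPy fancy indexing, not the TensorFlow solution being asked for; the function name is hypothetical):

```python
import numpy as np
from itertools import permutations

def all_permutations_np(x):
    """Hypothetical reference helper: x has shape [B, S, L], returns [B, S!, S, L].

    out[b, p, s, :] == x[b, perms[p][s], :], with perms in itertools order.
    """
    S = x.shape[1]
    perms = np.array(list(permutations(range(S))))  # shape [S!, S]
    # A single integer index array on axis 1 replaces that axis
    # by the index array's shape: [B] + [S!, S] + [L].
    return x[:, perms]

B, S, L = 5, 3, 10
x = np.random.randn(B, S, L)
out = all_permutations_np(x)
print(out.shape)  # (5, 6, 3, 10)
```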
This is what I've tried so far, but I can't get the right output shape. I think my mistake might be that I'm creating a batch_range but should also create a permutation_range.
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
data = tf.constant(np.random.randn(B, S, L))  # renamed from `input` to avoid shadowing the builtin
perms = list(permutations(range(S)))  # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indices = tf.concat([batch_range, perms], axis=3)
# result renamed from `permutations` so it does not shadow the imported function
result = tf.gather_nd(tf.tile(tf.reshape(data, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)
# I get a [B, P, S, S, L] instead of the desired [B, P, S, L]
I posted one possible "solution" below, but I think there is still a problem with it: when I tested it with B > 1, the results didn't look right.
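The extra axis in the first attempt comes from how tf.gather_nd reads its indices: the size of the last dimension of `indices` is the number of leading axes it indexes, and the remaining axes come back as a slice, so the output shape is indices.shape[:-1] + params.shape[depth:]. With depth-2 indices into the tiled [B, P, S, L] tensor, each lookup returns an [S, L] slice, hence [B, P, S, S, L]. A minimal sketch of one way around it, assuming TensorFlow 2.x eager mode: keep the depth-2 (b, source position) indices but gather from the original, untiled [B, S, L] tensor, so each lookup returns a single length-L vector.

```python
import numpy as np
import tensorflow as tf
from itertools import permutations

B, S, L = 5, 3, 10
data = tf.constant(np.random.randn(B, S, L))

perm_list = list(permutations(range(S)))
P = len(perm_list)

# [B, P, S, 1]: which source vector goes to position s under permutation p
perm_idx = tf.tile(tf.reshape(tf.constant(perm_list, dtype=tf.int32), [1, P, S, 1]), [B, 1, 1, 1])
# [B, P, S, 1]: the batch index b, repeated over permutations and positions
batch_idx = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), [B, 1, 1, 1]), [1, P, S, 1])
# [B, P, S, 2]: each entry is a (b, source_position) pair
indices = tf.concat([batch_idx, perm_idx], axis=3)

# Depth-2 indices into a rank-3 tensor: each lookup returns one [L] vector,
# so the output shape is indices.shape[:-1] + [L] = [B, P, S, L].
result = tf.gather_nd(data, indices)
```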
I know this is late, but I ran into the same problem and wanted to share my solution. I also generate the permutation list, then build a stack of permutation matrices from it and multiply that with the tensor. It doesn't use tf.gather_nd(), just a clean matrix multiplication.
import tensorflow as tf
import numpy as np
from itertools import permutations
B = 5  # batch size
S = 3  # number of vectors to permute
L = 10  # length of each of the S vectors
data = tf.constant(np.random.randn(B, S, L))
perms = list(permutations(range(S)))  # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
N = len(perms)
# from here on, new code:
# indexing the [S x S] identity matrix with the permutations reorders its rows,
# giving an [N, S, S] stack of N permutation matrices
perm_mat = tf.constant(np.eye(S)[np.array(perms)], dtype=tf.float64)
# multiplying by these matrices gives the permuted output; data gets a new axis
# so it broadcasts against the N permutation matrices, giving [B, N, S, L]
res = tf.linalg.matmul(perm_mat, data[:, tf.newaxis, ...])
print(res)
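Why the multiplication works: np.eye(S)[perm] reorders the rows of the identity, so row i of that matrix is one-hot at column perm[i], and multiplying it by an [S, L] block selects row perm[i] of the block as output row i, which is exactly the gather. A small NumPy check of that claim (illustrative only; the variable names are local to this sketch):

```python
import numpy as np
from itertools import permutations

S, L = 3, 4
x = np.random.randn(S, L)
perms = np.array(list(permutations(range(S))))  # [S!, S]

# [S!, S, S]: one permutation matrix per permutation (identity rows reordered).
# perm_mat[3] corresponds to (1, 2, 0) and is [[0, 1, 0], [0, 0, 1], [1, 0, 0]].
perm_mat = np.eye(S)[perms]

# x broadcasts over the S! matrices: [S!, S, S] @ [S, L] -> [S!, S, L].
# Row i of reordered[p] is x[perms[p][i]].
reordered = perm_mat @ x
```

The trade-off is that the product spends about S·S·L multiply-adds per block where a gather only copies S·L values, which is negligible for small S.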
I think I just found an answer. Please correct me if I'm wrong or if there is an easier way to do this:
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
data = tf.constant(np.random.randn(B, S, L))  # renamed from `input` to avoid shadowing the builtin
perms = list(permutations(range(S)))  # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indices = tf.concat([batch_range, perm_range, perms], axis=3)
# depth-3 indices (batch, permutation, source position) into the tiled [B, P, S, L]
# tensor return one length-L vector each, so the result is [B, P, S, L]
result = tf.gather_nd(tf.tile(tf.reshape(data, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)
print(result)
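On the question of an easier way: since every batch element uses the same permutations, a plain tf.gather along axis 1 appears to be enough, with no batch or permutation ranges and no tiling. For params of shape [B, S, L] and an index tensor of shape [S!, S], tf.gather(..., axis=1) replaces axis 1 by the index shape, giving [B, S!, S, L]. A minimal sketch, assuming TensorFlow 2.x eager execution (this swaps in tf.gather instead of the tf.gather_nd approach above):

```python
import numpy as np
import tensorflow as tf
from itertools import permutations

B, S, L = 5, 3, 10
data = tf.constant(np.random.randn(B, S, L))

# [S!, S] int32 index tensor, shared by every batch element
perm_idx = tf.constant(list(permutations(range(S))), dtype=tf.int32)

# Output shape: data.shape[:1] + perm_idx.shape + data.shape[2:] = [B, S!, S, L]
all_perms = tf.gather(data, perm_idx, axis=1)
```

Because the index tensor does not depend on B, the same code works unchanged for any batch size.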