Anthony D'Amato

Getting all possible permutations along an axis with tf.gather_nd

I'm trying to extract all possible permutations of a tensor along a specific axis. My input is a [B, S, L] tensor (B batches of S vectors of length L), and I want to extract all possible permutations of these vectors (the S! permutations), i.e. a [B, S!, S, L] tensor as output. Here is what I've tried so far, but I'm struggling to get the right output shape. I think my mistake might be that I create a batch_range but should create a permutation_range as well.

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indices = tf.concat([batch_range, perms], axis=3)

permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)
# I get a [B, P, S, S, L] tensor instead of the desired [B, P, S, L], where P = S!
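The extra S dimension comes from how tf.gather_nd reads the last axis of the index tensor: each index here has depth 2, so it addresses only the first two axes of the rank-4 tiled tensor, and every index returns a whole [S, L] slice, giving [B, P, S] + [S, L]. The permutation values also end up indexing the tiled P axis rather than the S axis. Below is a minimal NumPy sketch of that shape rule; `gather_nd_like` is a hypothetical stand-in for tf.gather_nd without batch_dims, not a TensorFlow function.

```python
import numpy as np
from itertools import permutations

def gather_nd_like(params, indices):
    # Hypothetical NumPy stand-in for tf.gather_nd (no batch_dims): the last axis
    # of `indices` holds coordinates into the leading axes of `params`; the
    # remaining axes of `params` come back as whole slices.
    return params[tuple(np.moveaxis(indices, -1, 0))]

B, S, L = 5, 3, 10
x = np.random.randn(B, S, L)
perms = np.array(list(permutations(range(S))))  # [P, S], P = S!
P = len(perms)

tiled = np.tile(x[:, np.newaxis], (1, P, 1, 1))  # [B, P, S, L], like the question
batch_range = np.broadcast_to(np.arange(B).reshape(B, 1, 1, 1), (B, P, S, 1))
perm_idx = np.broadcast_to(perms.reshape(1, P, S, 1), (B, P, S, 1))
indices = np.concatenate([batch_range, perm_idx], axis=-1)  # [B, P, S, 2]

# Depth-2 indices on the tiled tensor address axes 0 and 1 (batch, tiled copy),
# so each index returns an [S, L] slice.
out = gather_nd_like(tiled, indices)
print(out.shape)  # (5, 6, 3, 3, 10)

# The same indices on the untiled [B, S, L] input address (batch, vector) pairs
# and return one length-L vector per index.
fixed = gather_nd_like(x, indices)
print(fixed.shape)  # (5, 6, 3, 10)
assert np.array_equal(fixed, x[:, perms])
```

By the same shape rule, this suggests the original depth-2 index tensor would give [B, P, S, L] if passed to tf.gather_nd together with the untiled input instead of the tiled one.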

I posted one possible 'solution' below, but I think there is still a problem with it: when I tested it with B > 1, it did not seem to work well.
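As a ground truth for comparing any TensorFlow attempt, the desired [B, S!, S, L] result is simply the input indexed along axis 1 by every permutation. Below is a minimal NumPy reference sketch; the helper name `all_permutations_np` is hypothetical. In TensorFlow, `tf.gather(input, perms, axis=1)` with `perms` of shape [S!, S] follows the same shape rule (params.shape[:1] + indices.shape + params.shape[2:]), so it should produce the same [B, S!, S, L] tensor directly, although that call is not exercised here. The output holds B·S!·S·L values, so it grows factorially with S.

```python
import numpy as np
from itertools import permutations
from math import factorial

def all_permutations_np(x):
    # Hypothetical reference helper: x is [B, S, L], result is [B, S!, S, L]
    S = x.shape[1]
    perms = np.array(list(permutations(range(S))))  # [S!, S]
    # A single [P, S] index array on axis 1 replaces that axis with [P, S]
    return x[:, perms]

B, S, L = 5, 3, 10
x = np.random.randn(B, S, L)
ref = all_permutations_np(x)
assert ref.shape == (B, factorial(S), S, L)
print(ref.shape)  # (5, 6, 3, 10)

# Spot check: in itertools order, permutation number 3 is (1, 2, 0)
assert np.array_equal(ref[:, 3], x[:, [1, 2, 0]])
```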


Answers (2)

meister hubert

I know this is late, but I came across the same problem and wanted to share my solution. I also generate the list of permutations, build a tensor of permutation matrices from it, and multiply that with the data tensor. It doesn't use tf.gather_nd(), just a clean matrix multiplication.
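Why a matrix product can reorder the vectors: indexing the rows of an identity matrix with a permutation p gives a matrix whose row i is the unit vector e_{p[i]}, so row i of the product with X is X[p[i]]. A minimal NumPy sketch for a single permutation (the sizes and `p` are arbitrary choices for illustration):

```python
import numpy as np

S, L = 3, 4
X = np.arange(S * L, dtype=float).reshape(S, L)  # S vectors of length L
p = [1, 2, 0]  # one permutation of range(S)

P = np.eye(S)[p]  # rows of the identity, reordered by p
print(P)
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]

# Row i of P @ X is X[p[i]], the same result as fancy indexing
assert np.array_equal(P @ X, X[p])
```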

import tensorflow as tf
import numpy as np
from itertools import permutations

B = 5 # batch size
S = 3 # number of vectors to permute
L = 10 # length of each of the S vectors

data = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
N = len(perms)

# from here new code:
# Index the rows of an [S x S] identity matrix with each permutation. This gives a
# [N, S, S] stack of permutation matrices: row i of matrix n is the unit vector e_{perms[n][i]}
perm_mat = tf.constant(np.eye(S)[np.array(perms)], dtype=tf.float64)
# Multiplying by a permutation matrix reorders the S vectors. matmul broadcasts the
# batch dimensions [N] and [B, 1] to [B, N], so res has shape [B, N, S, L]
res = tf.linalg.matmul(perm_mat, data[:, tf.newaxis, ...])
print(res)
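The broadcasting step can be checked with NumPy, whose `np.matmul` applies the same rule to leading batch dimensions: [N, S, S] @ [B, 1, S, L] broadcasts [N] against [B, 1] and yields [B, N, S, L]. A minimal sketch comparing the product with direct indexing, assuming float64 data so that the 0/1 products are exact:

```python
import numpy as np
from itertools import permutations

B, S, L = 5, 3, 10
data = np.random.randn(B, S, L)
perms = np.array(list(permutations(range(S))))  # [N, S]
N = len(perms)

perm_mat = np.eye(S)[perms]  # [N, S, S] stack of permutation matrices
# Batch dims [N] and [B, 1] broadcast to [B, N]
res = np.matmul(perm_mat, data[:, np.newaxis])  # [B, N, S, L]
print(res.shape)  # (5, 6, 3, 10)

# Same values as indexing the S axis with every permutation
assert np.array_equal(res, data[:, perms])
```

One design note: the product does S·S·L multiply-adds per permutation where a gather only copies S·L values, which is negligible for small S.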

Upvotes: 0

Anthony D'Amato

I think I just found an answer. Please correct me if I'm wrong or if there is an easier way to do this:

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # e.g. with S = 3: (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indices = tf.concat([batch_range, perm_range, perms], axis=3)  # [B, P, S, 3]

perm_output = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indices)  # [B, P, S, L]
print(perm_output)
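Whether the depth-3 indices also behave for B > 1 can be checked outside TensorFlow. The sketch below rebuilds the same [B, P, S, 3] index tensor with B = 5 and emulates tf.gather_nd with a hypothetical helper `gather_nd_like` (not a TensorFlow function). Each index (b, p, perms[p][s]) now addresses a single length-L row of the tiled tensor, so the output is [B, P, S, L], and in this emulation it matches direct indexing of the input for every batch entry.

```python
import numpy as np
from itertools import permutations

def gather_nd_like(params, indices):
    # Hypothetical NumPy stand-in for tf.gather_nd without batch_dims
    return params[tuple(np.moveaxis(indices, -1, 0))]

B, S, L = 5, 3, 10
x = np.random.randn(B, S, L)
perms = np.array(list(permutations(range(S))))  # [P, S]
P = len(perms)

shape = (B, P, S, 1)
batch_range = np.broadcast_to(np.arange(B).reshape(B, 1, 1, 1), shape)
perm_range = np.broadcast_to(np.arange(P).reshape(1, P, 1, 1), shape)
perm_idx = np.broadcast_to(perms.reshape(1, P, S, 1), shape)
indices = np.concatenate([batch_range, perm_range, perm_idx], axis=-1)  # [B, P, S, 3]

tiled = np.tile(x[:, np.newaxis], (1, P, 1, 1))  # [B, P, S, L]
out = gather_nd_like(tiled, indices)
print(out.shape)  # (5, 6, 3, 10)

# out[b, p, s] == x[b, perms[p, s]] for every batch entry b
assert np.array_equal(out, x[:, perms])
```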
