I have a tensor with shape (batch size, sequence length, 2, N, K) (in my particular case, the 2 represents the (x, y) spatial position). N is the number of variables and K is the number of values each variable can take.
I want to generate a tensor with shape (batch size, sequence length, 2, N, K^N), where K^N arises from all possible configurations of the N "variables", each taking on one of its K possible values.
How can I do this efficiently in TensorFlow using slicing or gathering?
Let me offer an example. For the purposes of illustration, I'm going to omit the 2 leading dimensions of batch size and sequence length.
Suppose x is a 3D tensor of shape (2, N=3, K=4).
The first configuration would be (abusing notation slightly) to take a slice like x[:, (0, 1, 2), (0, 0, 0)]; here, the N=3 variables all take on their first values. The second configuration would be a slice like x[:, (0, 1, 2), (0, 0, 1)]; here, the first two variables take on their first values and the third variable takes on its second value. This continues through all 4^3 = 64 possible configurations, the last being x[:, (0, 1, 2), (3, 3, 3)].
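The ordering described here is counting in base K with N digits: configuration number c assigns variable n the value index (c // K^(N-1-n)) % K, so variable 0 changes slowest and the last variable changes fastest. A small sketch (config_digits is a hypothetical helper, not part of any library) that checks this against itertools.product:

```python
from itertools import product


def config_digits(c, n_vars, k_vals):
    """Hypothetical helper: value index of each variable in configuration c.

    Configurations are counted like N-digit base-K numbers, with variable 0
    as the most significant digit: (0, 0, 0), (0, 0, 1), ..., (K-1, K-1, K-1).
    """
    return tuple((c // k_vals ** (n_vars - 1 - n)) % k_vals for n in range(n_vars))


N, K = 3, 4
configs = [config_digits(c, N, K) for c in range(K ** N)]

print(len(configs))                          # 64
print(configs[0], configs[1], configs[-1])   # (0, 0, 0) (0, 0, 1) (3, 3, 3)

# itertools.product enumerates the same configurations in the same order.
assert configs == list(product(range(K), repeat=N))
```

Because of this, a configuration's value indices can be computed from its number with integer division and modulo instead of being materialized with a meshgrid.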
If I stacked all of these, the result would be a tensor with shape (2, 3, 4^3).
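To pin down the intended output, here is a plain-Python reference (no TensorFlow) that can be used to check a vectorized version on small inputs. It assumes the configurations are ordered as described above, with the last variable changing fastest, which is exactly what itertools.product yields; stack_configurations and the input values are hypothetical.

```python
from itertools import product


def stack_configurations(x):
    """Hypothetical plain-Python reference for the (2, N, K) example.

    x is a nested list of shape (D, N, K). Returns a nested list of shape
    (D, N, K**N) where result[d][n][c] = x[d][n][configs[c][n]] and configs
    holds every N-tuple of value indices in itertools.product order.
    """
    n_vars, k_vals = len(x[0]), len(x[0][0])
    configs = list(product(range(k_vals), repeat=n_vars))
    return [
        [[plane[n][cfg[n]] for cfg in configs] for n in range(n_vars)]
        for plane in x
    ]


# Hypothetical input of shape (2, 3, 4); each value encodes (d, n, k) as 100*d + 10*n + k.
x = [[[100 * d + 10 * n + k for k in range(4)] for n in range(3)] for d in range(2)]
result = stack_configurations(x)

print(len(result), len(result[0]), len(result[0][0]))  # 2 3 64
print(result[0][2][:8])    # [20, 21, 22, 23, 20, 21, 22, 23]
print(result[1][0][::16])  # [100, 101, 102, 103]
```

Row n of the result cycles through x[d][n] with period K^(N-1-n): the last variable cycles every step, the first only every K^(N-1) steps.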
IIUC, and if it is still relevant, here is one way to solve your problem purely with TensorFlow (note that I worked with the 3D tensor and omitted the two leading dimensions):
import tensorflow as tf
tf.random.set_seed(111)
x = tf.random.uniform((2, 3, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')
# Number of configurations: K^N
num_configs = x_shape[-1] ** x_shape[1]
# All K^N configurations, one per row: (0, 0, 0), (0, 0, 1), ..., (3, 3, 3)
combination_range = tf.range(x_shape[-1])
xx, yy, zz = tf.meshgrid(combination_range, combination_range, combination_range, indexing='ij')
combinations = tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1]), tf.reshape(zz, [-1])], axis=1)
print('combinations -->', combinations, '\n')
# Flat index vectors in the row-major order of the output (first dim, variable, configuration),
# so that the final reshape puts every gathered value into the right cell
first_dim = tf.repeat(tf.range(x_shape[0]), x_shape[1] * num_configs)
second_dim = tf.tile(tf.repeat(tf.range(x_shape[1]), num_configs), [x_shape[0]])
# combinations[c, n] is the value index taken by variable n in configuration c
third_dim = tf.tile(tf.reshape(tf.transpose(combinations), [-1]), [x_shape[0]])
result = tf.gather_nd(x, tf.stack([first_dim, second_dim, third_dim], axis=1))
result = tf.reshape(result, (x_shape[0], x_shape[1], num_configs))
print('final result -->', result)
x --> tf.Tensor(
[[[ 5 14 1 14]
[ 2 1 12 3]
[ 2 5 7 10]]
[[ 0 9 0 12]
[12 11 0 1]
[ 2 6 1 12]]], shape=(2, 3, 4), dtype=int32)
combinations --> tf.Tensor(
[[0 0 0]
[0 0 1]
[0 0 2]
[0 0 3]
[0 1 0]
[0 1 1]
[0 1 2]
[0 1 3]
[0 2 0]
[0 2 1]
[0 2 2]
[0 2 3]
[0 3 0]
[0 3 1]
[0 3 2]
[0 3 3]
[1 0 0]
[1 0 1]
[1 0 2]
[1 0 3]
[1 1 0]
[1 1 1]
[1 1 2]
[1 1 3]
[1 2 0]
[1 2 1]
[1 2 2]
[1 2 3]
[1 3 0]
[1 3 1]
[1 3 2]
[1 3 3]
[2 0 0]
[2 0 1]
[2 0 2]
[2 0 3]
[2 1 0]
[2 1 1]
[2 1 2]
[2 1 3]
[2 2 0]
[2 2 1]
[2 2 2]
[2 2 3]
[2 3 0]
[2 3 1]
[2 3 2]
[2 3 3]
[3 0 0]
[3 0 1]
[3 0 2]
[3 0 3]
[3 1 0]
[3 1 1]
[3 1 2]
[3 1 3]
[3 2 0]
[3 2 1]
[3 2 2]
[3 2 3]
[3 3 0]
[3 3 1]
[3 3 2]
[3 3 3]], shape=(64, 3), dtype=int32)
final result --> tf.Tensor(
[[[ 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 14 14 14 14 14 14 14
14 14 14 14 14 14 14 14 14 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]
[ 2 2 2 2 1 1 1 1 12 12 12 12 3 3 3 3 2 2 2 2 1 1 1
1 12 12 12 12 3 3 3 3 2 2 2 2 1 1 1 1 12 12 12 12 3 3
3 3 2 2 2 2 1 1 1 1 12 12 12 12 3 3 3 3]
[ 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7
10 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7 10 2 5
7 10 2 5 7 10 2 5 7 10 2 5 7 10 2 5 7 10]]
[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 9 9 9 9 9 9
9 9 9 9 9 9 9 9 9 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12]
[12 12 12 12 11 11 11 11 0 0 0 0 1 1 1 1 12 12 12 12 11 11 11
11 0 0 0 0 1 1 1 1 12 12 12 12 11 11 11 11 0 0 0 0 1 1
1 1 12 12 12 12 11 11 11 11 0 0 0 0 1 1 1 1]
[ 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1
12 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1 12 2 6
1 12 2 6 1 12 2 6 1 12 2 6 1 12 2 6 1 12]]], shape=(2, 3, 64), dtype=int32)
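The key requirement when building flat index vectors for tf.gather_nd and reshaping afterwards is that the vectors enumerate output positions in the same row-major order as the target shape (D, N, K^N): the first dimension varies slowest and the configuration index fastest. If the three vectors are out of step, the reshape puts values into the wrong cells. A plain-Python sketch of that layout (no TensorFlow needed), using a hypothetical input whose values encode their own (d, n, k) position:

```python
from itertools import product

D, N, K = 2, 3, 4
C = K ** N
# Hypothetical input of shape (D, N, K); each value encodes (d, n, k) as 100*d + 10*n + k.
x = [[[100 * d + 10 * n + k for k in range(K)] for n in range(N)] for d in range(D)]
combinations = list(product(range(K), repeat=N))  # shape (C, N)

# Flat index vectors in row-major (d, n, c) order: d slowest, c fastest.
# TF analogue: tf.repeat(tf.range(D), N * C)
first_dim = [d for d in range(D) for _ in range(N * C)]
# TF analogue: tf.tile(tf.repeat(tf.range(N), C), [D])
second_dim = [n for _ in range(D) for n in range(N) for _ in range(C)]
# TF analogue: tf.tile(tf.reshape(tf.transpose(combinations), [-1]), [D])
third_dim = [combinations[c][n] for _ in range(D) for n in range(N) for c in range(C)]

# "gather_nd": one lookup per index triple, producing a flat list.
flat = [x[d][n][k] for d, n, k in zip(first_dim, second_dim, third_dim)]
# "reshape" to (D, N, C): element (d, n, c) sits at flat index (d * N + n) * C + c.
result = [[flat[(d * N + n) * C:(d * N + n + 1) * C] for n in range(N)] for d in range(D)]

# Every entry equals x[d][n][combinations[c][n]].
assert all(
    result[d][n][c] == x[d][n][combinations[c][n]]
    for d in range(D) for n in range(N) for c in range(C)
)
```

Reading down one column of the output, result[d][n][0], ..., result[0][0][0 .. 15] is constant for variable 0 and result[d][N-1] cycles through all K values, which is an easy visual check on the printed TensorFlow result.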
For arbitrary N, try this:
import tensorflow as tf
tf.random.set_seed(111)
x = tf.random.uniform((2, 6, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')
# Number of configurations: K^N
num_configs = x_shape[-1] ** x_shape[1]
combination_range = tf.range(x_shape[-1])
# Pass N copies of range(K) to meshgrid, one per variable
outputs = tf.meshgrid(*tf.unstack(tf.repeat(combination_range[tf.newaxis, ...], x_shape[1], axis=0), axis=0), indexing='ij')
combinations = tf.stack([tf.reshape(o, [-1]) for o in outputs], axis=1)
print('combinations -->', combinations, '\n')
# Flat index vectors in the row-major order of the output (first dim, variable, configuration)
first_dim = tf.repeat(tf.range(x_shape[0]), x_shape[1] * num_configs)
second_dim = tf.tile(tf.repeat(tf.range(x_shape[1]), num_configs), [x_shape[0]])
# combinations[c, n] is the value index taken by variable n in configuration c
third_dim = tf.tile(tf.reshape(tf.transpose(combinations), [-1]), [x_shape[0]])
result = tf.gather_nd(x, tf.stack([first_dim, second_dim, third_dim], axis=1))
result = tf.reshape(result, (x_shape[0], x_shape[1], num_configs))
print('final result -->', result)
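The snippets above drop the two leading dimensions. For the full (batch size, sequence length, 2, N, K) tensor, one alternative (a different technique from the meshgrid plus gather_nd approach above) is to compute each configuration's value indices as the base-K digits of the configuration number, broadcast that (N, K^N) index table over all leading dimensions, and call tf.gather on the last axis with batch_dims covering every other axis. This is a sketch assuming N and K are known statically (Python ints from x.shape); all_configurations is a hypothetical name.

```python
import tensorflow as tf


def all_configurations(x):
    """Hypothetical helper: x has shape (..., N, K); returns shape (..., N, K**N).

    Output[..., n, c] = x[..., n, digit_n(c)], where digit_n(c) is the n-th
    base-K digit of c (variable 0 most significant), so configurations are
    ordered (0, ..., 0), (0, ..., 1), ..., (K-1, ..., K-1).
    Assumes N and K are static (not None) in x.shape.
    """
    n_vars, k_vals = x.shape[-2], x.shape[-1]
    num_configs = k_vals ** n_vars
    configs = tf.range(num_configs)                                  # (K**N,)
    # Place value of variable n is K**(N-1-n): variable 0 changes slowest.
    place = tf.constant([k_vals ** (n_vars - 1 - n) for n in range(n_vars)], dtype=tf.int32)
    digits = (configs[tf.newaxis, :] // place[:, tf.newaxis]) % k_vals  # (N, K**N)
    # Broadcast the index table over every leading dimension of x.
    target_shape = tf.concat([tf.shape(x)[:-1], [num_configs]], axis=0)
    indices = tf.broadcast_to(digits, target_shape)                  # (..., N, K**N)
    rank = len(x.shape)
    # Gather along the last axis, treating all other axes as batch axes.
    return tf.gather(x, indices, axis=rank - 1, batch_dims=rank - 1)


# Hypothetical input: batch=2, sequence length=3, the (x, y) axis of size 2, N=3, K=4.
x = tf.reshape(tf.range(2 * 3 * 2 * 3 * 4), (2, 3, 2, 3, 4))
out = all_configurations(x)
print(out.shape)  # (2, 3, 2, 3, 64)
```

The output has batch * seq * 2 * N * K^N elements, so memory grows exponentially in N (N=6, K=4 already gives 4096 configurations per variable row); for larger N it may be necessary to process configurations in chunks.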