Rylan Schaeffer

Tensorflow: How to slice/gather all possible configurations?

I have a tensor with shape (batch size, sequence length, 2, N, K) (in my particular case, the 2 represents the (x, y) spatial position). N is the number of variables and K is the number of values each variable can take.

I want to generate a tensor with shape (batch size, sequence length, 2, N, K^N), where K^N counts all possible configurations: every way of assigning each of the N "variables" one of its K possible values.

How can I efficiently do this in TensorFlow using slicing or gathering?

Let me offer a concrete example. For the purposes of illustration, I'm going to omit the 2 leading dimensions of batch size and sequence length.

Suppose x is a 3D tensor of shape (2, N=3, K=4).

The first configuration would be (abusing notation slightly) to take a slice like x[:, (0, 1, 2), (0, 0, 0)]; here, the N=3 variables are all taking on their first values. The second configuration would be to take a slice like x[:, (0, 1, 2), (0, 0, 1)]; here, the first two variables are taking on their first values and the third variable is taking on its second value. This continues on, up to the 4^3=64 possible configurations, with the last being x[:, (0, 1, 2), (3, 3, 3)].
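Assuming the configurations are numbered c = 0, 1, ..., K^N - 1 in the order the examples suggest (the last variable changes fastest), configuration c is simply c written as an N-digit number in base K, with variable 0 as the most significant digit. The stacked result can then be described elementwise (indices are 0-based and the leading dimensions are written as ...):

```latex
\mathrm{result}[\dots, n, c] = x\bigl[\dots, n, k_n(c)\bigr],
\qquad
k_n(c) = \left\lfloor \frac{c}{K^{\,N-1-n}} \right\rfloor \bmod K,
\qquad 0 \le n < N,\; 0 \le c < K^N .
```

For N = 3 and K = 4, c = 0 gives (k_0, k_1, k_2) = (0, 0, 0), c = 1 gives (0, 0, 1) and c = 63 gives (3, 3, 3), matching the first, second and last configurations described above.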

If I stacked all these up, the result would be a tensor with shape (2, 3, 4^3).
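To pin down exactly what is being asked, here is a brute-force NumPy sketch (not TensorFlow, and far too slow for real use) that builds the stacked result by looping over the configurations. The helper name `all_configurations_bruteforce` is hypothetical. It enumerates configurations with `itertools.product`, which yields them in the order described above (last variable changes fastest), and it accepts any number of leading dimensions, so it also covers the full (batch size, sequence length, 2, N, K) case and can serve as a reference when testing a vectorized version.

```python
import itertools

import numpy as np


def all_configurations_bruteforce(x):
    """Hypothetical reference helper: (..., N, K) -> (..., N, K**N).

    Slice c of the output is x[..., (0, ..., N-1), combo_c], where combo_c is
    the c-th tuple produced by itertools.product (last variable changes fastest).
    """
    n_vars, k_vals = x.shape[-2], x.shape[-1]
    var_idx = np.arange(n_vars)
    slices = [
        # Paired advanced indexing: picks x[..., n, combo[n]] for every n.
        x[..., var_idx, np.array(combo)]
        for combo in itertools.product(range(k_vals), repeat=n_vars)
    ]
    return np.stack(slices, axis=-1)


x = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # x[b, n, k] = 12*b + 4*n + k
out = all_configurations_bruteforce(x)
print(out.shape)     # (2, 3, 64)
print(out[0, :, 1])  # configuration (0, 0, 1) -> [0 4 9]
```

In NumPy the "abused" notation from the question is actually valid advanced indexing, which is why each slice is a single indexing expression here.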


Answers (1)

AloneTogether

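IIUC, and if it is still relevant, here's one way you can solve your problem purely with TensorFlow: build one (b, n, k) index triple per output element and gather them all at once with tf.gather_nd (note I worked with the 3D tensor and omitted the first two leading dimensions):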

import tensorflow as tf

tf.random.set_seed(111)
x = tf.random.uniform((2, 3, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')

# Every configuration of the N = 3 variables, one per row, e.g. [0, 0, 1].
combination_range = tf.range(x_shape[-1])
xx, yy, zz = tf.meshgrid(combination_range, combination_range, combination_range, indexing='ij')
combinations = tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1]), tf.reshape(zz, [-1])], axis=1)
print('combinations -->', combinations, '\n')

# One (b, n, c) grid entry per output element; all three grids have shape (2, 3, 64).
b_idx, n_idx, c_idx = tf.meshgrid(tf.range(x_shape[0]), tf.range(x_shape[1]),
                                  tf.range(tf.shape(combinations)[0]), indexing='ij')
# k_idx[b, n, c] = combinations[c, n]: the value taken by variable n in configuration c.
k_idx = tf.gather_nd(combinations, tf.stack([c_idx, n_idx], axis=-1))
# result[b, n, c] = x[b, n, combinations[c, n]], already in shape (2, 3, 64).
result = tf.gather_nd(x, tf.stack([b_idx, n_idx, k_idx], axis=-1))
print('final result -->', result)

x --> tf.Tensor(
[[[ 5 14  1 14]
  [ 2  1 12  3]
  [ 2  5  7 10]]

 [[ 0  9  0 12]
  [12 11  0  1]
  [ 2  6  1 12]]], shape=(2, 3, 4), dtype=int32) 

combinations --> tf.Tensor(
[[0 0 0]
 [0 0 1]
 [0 0 2]
 [0 0 3]
 [0 1 0]
 [0 1 1]
 [0 1 2]
 [0 1 3]
 [0 2 0]
 [0 2 1]
 [0 2 2]
 [0 2 3]
 [0 3 0]
 [0 3 1]
 [0 3 2]
 [0 3 3]
 [1 0 0]
 [1 0 1]
 [1 0 2]
 [1 0 3]
 [1 1 0]
 [1 1 1]
 [1 1 2]
 [1 1 3]
 [1 2 0]
 [1 2 1]
 [1 2 2]
 [1 2 3]
 [1 3 0]
 [1 3 1]
 [1 3 2]
 [1 3 3]
 [2 0 0]
 [2 0 1]
 [2 0 2]
 [2 0 3]
 [2 1 0]
 [2 1 1]
 [2 1 2]
 [2 1 3]
 [2 2 0]
 [2 2 1]
 [2 2 2]
 [2 2 3]
 [2 3 0]
 [2 3 1]
 [2 3 2]
 [2 3 3]
 [3 0 0]
 [3 0 1]
 [3 0 2]
 [3 0 3]
 [3 1 0]
 [3 1 1]
 [3 1 2]
 [3 1 3]
 [3 2 0]
 [3 2 1]
 [3 2 2]
 [3 2 3]
 [3 3 0]
 [3 3 1]
 [3 3 2]
 [3 3 3]], shape=(64, 3), dtype=int32) 

final result --> tf.Tensor(
[[[ 5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5 14 14 14 14 14 14 14
   14 14 14 14 14 14 14 14 14  1  1  1  1  1  1  1  1  1  1  1  1  1  1
    1  1 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]
  [ 2  2  2  2  1  1  1  1 12 12 12 12  3  3  3  3  2  2  2  2  1  1  1
    1 12 12 12 12  3  3  3  3  2  2  2  2  1  1  1  1 12 12 12 12  3  3
    3  3  2  2  2  2  1  1  1  1 12 12 12 12  3  3  3  3]
  [ 2  5  7 10  2  5  7 10  2  5  7 10  2  5  7 10  2  5  7 10  2  5  7
   10  2  5  7 10  2  5  7 10  2  5  7 10  2  5  7 10  2  5  7 10  2  5
    7 10  2  5  7 10  2  5  7 10  2  5  7 10  2  5  7 10]]

 [[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  9  9  9  9  9  9  9
    9  9  9  9  9  9  9  9  9  0  0  0  0  0  0  0  0  0  0  0  0  0  0
    0  0 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12]
  [12 12 12 12 11 11 11 11  0  0  0  0  1  1  1  1 12 12 12 12 11 11 11
   11  0  0  0  0  1  1  1  1 12 12 12 12 11 11 11 11  0  0  0  0  1  1
    1  1 12 12 12 12 11 11 11 11  0  0  0  0  1  1  1  1]
  [ 2  6  1 12  2  6  1 12  2  6  1 12  2  6  1 12  2  6  1 12  2  6  1
   12  2  6  1 12  2  6  1 12  2  6  1 12  2  6  1 12  2  6  1 12  2  6
    1 12  2  6  1 12  2  6  1 12  2  6  1 12  2  6  1 12]]], shape=(2, 3, 64), dtype=int32)

For arbitrary N, try this:

import tensorflow as tf

tf.random.set_seed(111)
x = tf.random.uniform((2, 6, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')

# One meshgrid input per variable; N is read statically from x.shape[1].
combination_range = tf.range(x_shape[-1])
outputs = tf.meshgrid(*([combination_range] * x.shape[1]), indexing='ij')
combinations = tf.stack([tf.reshape(o, [-1]) for o in outputs], axis=1)  # shape (K**N, N)
print('combinations -->', combinations, '\n')

# Same gather as above: result[b, n, c] = x[b, n, combinations[c, n]].
b_idx, n_idx, c_idx = tf.meshgrid(tf.range(x_shape[0]), tf.range(x_shape[1]),
                                  tf.range(tf.shape(combinations)[0]), indexing='ij')
k_idx = tf.gather_nd(combinations, tf.stack([c_idx, n_idx], axis=-1))
result = tf.gather_nd(x, tf.stack([b_idx, n_idx, k_idx], axis=-1))  # shape (2, N, K**N)
print('final result -->', result)
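The snippets above drop the leading dimensions. For the full (batch size, sequence length, 2, N, K) tensor from the question, one option is a sketch like the following, which swaps in `tf.gather` with `batch_dims` in place of the index-triple / `tf.gather_nd` construction: the (N, K^N) table of configurations is broadcast over all leading dimensions and each length-K row of `x` is gathered independently. The helper name `all_configurations_tf` is hypothetical, and it assumes the rank of `x` and the sizes N and K are known statically. It also builds the configuration table by base-K decoding of each configuration index rather than with `tf.meshgrid`; both give the same ordering (last variable fastest).

```python
import tensorflow as tf


def all_configurations_tf(x):
    """Hypothetical helper: (..., N, K) -> (..., N, K**N).

    out[..., n, c] = x[..., n, combinations[c, n]]; assumes a static rank and
    static N and K (for example x of shape (batch, seq_len, 2, N, K)).
    """
    rank = len(x.shape)
    n_vars, k_vals = x.shape[-2], x.shape[-1]
    num_configs = k_vals ** n_vars

    # Row c holds the N base-K digits of c (first variable = most significant),
    # the same ordering as tf.meshgrid(..., indexing='ij') in the answer above.
    powers = tf.constant([k_vals ** (n_vars - 1 - i) for i in range(n_vars)], dtype=tf.int32)
    c = tf.range(num_configs, dtype=tf.int32)[:, tf.newaxis]  # (K**N, 1)
    combinations = (c // powers) % k_vals                     # (K**N, N)

    # Broadcast the (N, K**N) table over every leading dimension of x.
    target_shape = tf.concat([tf.shape(x)[:-1], [num_configs]], axis=0)
    indices = tf.broadcast_to(tf.transpose(combinations), target_shape)

    # With batch_dims = rank - 1, every length-K row of x is gathered on its own.
    return tf.gather(x, indices, axis=rank - 1, batch_dims=rank - 1)


x = tf.random.uniform((8, 10, 2, 3, 4))  # (batch, seq_len, 2, N=3, K=4)
out = all_configurations_tf(x)
print(out.shape)  # (8, 10, 2, 3, 64)
```

Because only the last axis is gathered, the same helper also works on the 3D tensors used in the answer. Keep in mind that the output has K^N entries per variable, so it grows exponentially with N.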
