Reputation: 2140
I need to construct a matrix z that contains all combinations of pairs of rows of a matrix x.
x = tf.constant([[1, 3],
[2, 4],
[0, 2],
[0, 1]], dtype=tf.int32)
z=[[[1,2],
[1,0],
[1,0],
[2,0],
[2,0],
[0,0]],
[[3,4],
[3,2],
[3,1],
[4,2],
[4,1],
[2,1]]]
It pairs each value with every later value in the same column.
I could not find any function or come up with a good idea to do that.
Update 1
So I need the final shape to be 2*6*2, like the z above.
Upvotes: 3
Views: 324
Reputation: 59711
Here is a way to do that without a loop:
import tensorflow as tf
x = tf.constant([[1, 3],
[2, 4],
[0, 2],
[0, 1]], dtype=tf.int32)
# Number of rows
n = tf.shape(x)[0]
# Grid of indices
ri = tf.range(0, n - 1)
rj = ri + 1
ii, jj = tf.meshgrid(ri, rj, indexing='ij')
# Stack together
grid = tf.stack([ii, jj], axis=-1)
# Get upper triangular part
m = ii < jj
idx = tf.boolean_mask(grid, m)
# Get values
g = tf.gather(x, idx, axis=0)
# Rearrange result
result = tf.transpose(g, [2, 0, 1])
print(result.numpy())
# [[[1 2]
# [1 0]
# [1 0]
# [2 0]
# [2 0]
# [0 0]]
#
# [[3 4]
# [3 2]
# [3 1]
# [4 2]
# [4 1]
# [2 1]]]
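To see which row pairs the mask keeps, here is a small plain-Python sketch of the index grid step. It does not use TensorFlow, and masked_grid_pairs is a hypothetical helper name used only for illustration, not something from the answer above.

```python
def masked_grid_pairs(n):
    """Mimic the meshgrid + mask step with plain lists (illustration only)."""
    ri = list(range(0, n - 1))   # candidate first indices: 0 .. n-2
    rj = [i + 1 for i in ri]     # candidate second indices: 1 .. n-1
    # Walk the grid row by row and keep cells where first index < second index
    return [[i, j] for i in ri for j in rj if i < j]


print(masked_grid_pairs(4))
# [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
```

For n rows this yields n * (n - 1) / 2 pairs, which is where the 6 in the 2*6*2 result shape comes from when n = 4.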
Upvotes: 3
Reputation: 24591
Unfortunately, this is a bit more complex than one would like when using only TensorFlow operators. I would create the indices for all combinations with a while_loop, then use tf.gather to collect the values:
import tensorflow as tf
x = tf.constant([[1, 3],
[2, 4],
[3, 2],
[0, 1]], dtype=tf.int32)
# Start from an empty (0, 2) tensor of index pairs
m = tf.constant([], shape=(0, 2), dtype=tf.int32)
_, idxs = tf.while_loop(
    lambda i, m: i < tf.shape(x)[0] - 1,
    # Append the pairs (i, j) for every j > i
    lambda i, m: (i + 1, tf.concat([
        m,
        tf.stack([tf.tile([i], (tf.shape(x)[0] - 1 - i,)),
                  tf.range(i + 1, tf.shape(x)[0])], axis=1)], axis=0)),
    loop_vars=(0, m),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, 2])))
# Gather the row pairs, move the column axis to the front, then flatten
z = tf.reshape(tf.transpose(tf.gather(x, idxs), (2, 0, 1)), (-1, 2))
# <tf.Tensor: shape=(12, 2), dtype=int32, numpy=
# array([[1, 2],
# [1, 3],
# [1, 0],
# [2, 3],
# [2, 0],
# [3, 0],
# [3, 4],
# [3, 2],
# [3, 1],
# [4, 2],
# [4, 1],
# [2, 1]])>
This should work in both TF1 and TF2.
If the length of x is known in advance, you don't need the while_loop and can simply precompute the indices in Python and place them in a constant.
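As a rough sketch of that precomputation, assuming the number of rows is a plain Python int, the index pairs can be built with itertools.combinations. The helper names pair_indices and column_pairs are made up for this example, and column_pairs is only a plain-Python reference for what the gather and transpose steps produce.

```python
from itertools import combinations


def pair_indices(n):
    """Return every index pair (i, j) with 0 <= i < j < n, in lexicographic order."""
    return [list(pair) for pair in combinations(range(n), 2)]


def column_pairs(rows):
    """Plain-Python reference: one list of value pairs per column of `rows`."""
    idxs = pair_indices(len(rows))
    n_cols = len(rows[0])
    return [[[rows[i][c], rows[j][c]] for i, j in idxs] for c in range(n_cols)]


print(pair_indices(4))
# [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]

x_rows = [[1, 3], [2, 4], [0, 2], [0, 1]]  # the x from the question
print(column_pairs(x_rows))
# [[[1, 2], [1, 0], [1, 0], [2, 0], [2, 0], [0, 0]], [[3, 4], [3, 2], [3, 1], [4, 2], [4, 1], [2, 1]]]
```

The list from pair_indices could then be wrapped in tf.constant(..., dtype=tf.int32) and passed to tf.gather in place of idxs from the while_loop.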
Upvotes: 2