I would like to compute all the combinations of two or more tensors. For example, for two tensors containing the values [1, 2] and [3, 4, 5] respectively, I would like to get the 6x2 tensor
[[1, 3],
[1, 4],
[1, 5],
[2, 3],
[2, 4],
[2, 5]]
To do this, I came up with the following hack:
import tensorflow as tf

def combine(x, y):
    x, y = x[:, None], y[:, None]
    x1 = tf.concat([x, tf.ones_like(x)], axis=-1)
    y1 = tf.concat([tf.ones_like(y), y], axis=-1)
    return tf.reshape(x1[:, None] * y1[None], (-1, 2))

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
print(combine(x, y))
# tf.Tensor(
# [[1 3]
# [1 4]
# [1 5]
# [2 3]
# [2 4]
# [2 5]], shape=(6, 2), dtype=int32)
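For reference, here is a sketch of why the hack produces this output, replaying its steps on the example values. Each input is turned into a column and padded with a column of ones, so the broadcast product keeps x (times 1) in the first column and y (times 1) in the second. It also shows the limits of the trick: it relies on multiplication, so it only works for numeric dtypes, and as written it handles exactly two inputs.

```python
import tensorflow as tf

x = tf.constant([1, 2])[:, None]       # shape (2, 1)
y = tf.constant([3, 4, 5])[:, None]    # shape (3, 1)

# Pad with ones: x1 = [[1 1] [2 1]], y1 = [[1 3] [1 4] [1 5]]
x1 = tf.concat([x, tf.ones_like(x)], axis=-1)   # shape (2, 2)
y1 = tf.concat([tf.ones_like(y), y], axis=-1)   # shape (3, 2)

# (2, 1, 2) * (1, 3, 2) broadcasts to (2, 3, 2);
# prod[i, j] = [x[i] * 1, 1 * y[j]]
prod = x1[:, None] * y1[None]

# Flattening the two grid axes gives the 6 pairs
pairs = tf.reshape(prod, (-1, 2))
print(pairs.numpy())
```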
However, I am not satisfied with this solution.
Is there a more efficient and/or more general way of doing this?
You can do that easily with tf.meshgrid:
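import tensorflow as tf

def combine(x, y):
    # 'ij' indexing: xx[i, j] = x[i], yy[i, j] = y[j]
    xx, yy = tf.meshgrid(x, y, indexing='ij')
    # Flatten both grids and pair them column-wise -> shape (len(x) * len(y), 2)
    return tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1])], axis=1)
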
x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
print(combine(x, y).numpy())
# [[1 3]
# [1 4]
# [1 5]
# [2 3]
# [2 4]
# [2 5]]
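Since the question also asks for something more general, here is a sketch of how the same idea might extend to any number of 1-D tensors: pass them all to tf.meshgrid, stack the grids along a new last axis, and flatten the grid axes. The name `combine_all` is hypothetical (it is not a TensorFlow function), and it assumes all inputs are 1-D and share one dtype, since tf.stack requires that.

```python
import tensorflow as tf

def combine_all(*tensors):
    # Hypothetical helper: Cartesian product of k 1-D tensors of the same dtype.
    # With 'ij' indexing, grid g has shape (n1, ..., nk) and g[i1, ..., ik] = tensors[g][i_g].
    grids = tf.meshgrid(*tensors, indexing='ij')
    # Stack along a new last axis: shape (n1, ..., nk, k)
    stacked = tf.stack(grids, axis=-1)
    # Merge all grid axes into one: shape (n1 * ... * nk, k)
    return tf.reshape(stacked, [-1, len(tensors)])

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
z = tf.constant([6, 7])
print(combine_all(x, y).numpy())
print(combine_all(x, y, z).shape)  # 2 * 3 * 2 = 12 rows, 3 columns
```

The indexing='ij' argument is what makes the first tensor vary slowest, matching the order in the question. With the default indexing='xy', the first two grid axes are swapped, so for two inputs the rows would come out as [1 3], [2 3], [1 4], [2 4], [1 5], [2 5].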