user209974

Reputation: 1927

How to produce all the combinations of the elements of two or more tensors?

I would like to compute all the combinations of the elements of two or more tensors. For example, for two tensors containing the values [1, 2] and [3, 4, 5] respectively, I would like to get the 6x2 tensor

[[1, 3],
 [1, 4],
 [1, 5],
 [2, 3],
 [2, 4],
 [2, 5]]

To do this, I came up with the following hack:

import tensorflow as tf

def combine(x, y):
  x, y = x[:, None], y[:, None]
  x1 = tf.concat([x, tf.ones_like(x)], axis=-1)
  y1 = tf.concat([tf.ones_like(y), y], axis=-1)
  return tf.reshape(x1[:, None] * y1[None], (-1, 2))

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
print(combine(x, y))
# tf.Tensor(
# [[1 3]
#  [1 4]
#  [1 5]
#  [2 3]
#  [2 4]
#  [2 5]], shape=(6, 2), dtype=int32)

However, I am not satisfied with this solution.

Is there a more efficient and/or general way of doing this?

Upvotes: 1

Views: 833

Answers (1)

javidcf

Reputation: 59681

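You can do that easily with tf.meshgrid. Passing indexing='ij' makes the first tensor vary slowest, which gives the row order you are after: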

import tensorflow as tf

def combine(x, y):
  xx, yy = tf.meshgrid(x, y, indexing='ij')
  return tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1])], axis=1)

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
print(combine(x, y).numpy())
# [[1 3]
#  [1 4]
#  [1 5]
#  [2 3]
#  [2 4]
#  [2 5]]
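
Since the question asks about two or more tensors, the same idea can probably be extended to any number of inputs. The following is only a sketch: `combine_all` is a hypothetical name, and it assumes every input is a 1-D tensor and that all inputs share one dtype (tf.stack requires that).

```python
import tensorflow as tf

def combine_all(*tensors):
  # Hypothetical helper: Cartesian product of any number of 1-D tensors.
  # Assumes all inputs are 1-D and share the same dtype.
  # indexing='ij' keeps the first tensor as the slowest-varying axis.
  grids = tf.meshgrid(*tensors, indexing='ij')
  # Stack the grids along a new last axis, then flatten every other axis
  # so that each row holds one combination.
  return tf.reshape(tf.stack(grids, axis=-1), [-1, len(tensors)])

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
z = tf.constant([6, 7])
result = combine_all(x, y, z)
print(result.numpy())
```

With the three tensors above, the result has 2 * 3 * 2 = 12 rows and 3 columns, and calling it with only x and y gives the same 6x2 tensor as the two-argument version.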

Upvotes: 2
