Sean Easter

In Keras, how to pairwise concatenate two inputs of different input size?

As a motivating example, say we have an edge-weight prediction problem in a network of heterogeneous nodes (e.g. images and text), and we want the concatenation of every possible pair of inputs. A simple example of what the data might look like:

import numpy as np

# two inputs of different shape
x = np.array([[1, 1],
              [2, 2],
              [3, 3]])
y = np.array([[4, 4, 4],
              [5, 5, 5]])

# a predicted feature we'd like to model (one value per (x row, y row) pair)
z = np.array([0, 1, 1, 0, 0, 0])

# every row of x concatenated with every row of y
joined = np.array([[1, 1, 4, 4, 4],
                   [1, 1, 5, 5, 5], 
                   [2, 2, 4, 4, 4], 
                   [2, 2, 5, 5, 5], 
                   [3, 3, 4, 4, 4], 
                   [3, 3, 5, 5, 5]])
some_model.fit([x, y], z)  # some_model: the model we'd like to build
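For reference, the target joined array can be produced from x and y in plain NumPy. This is a minimal sketch of the desired output only (not a Keras model), assuming the row order shown above, where each row of x is paired with every row of y in turn:

```python
import numpy as np

x = np.array([[1, 1],
              [2, 2],
              [3, 3]])
y = np.array([[4, 4, 4],
              [5, 5, 5]])

# Repeat each row of x once per row of y: x rows 0, 0, 1, 1, 2, 2
x_rep = np.repeat(x, len(y), axis=0)
# Stack the whole of y once per row of x: y rows 0, 1, 0, 1, 0, 1
y_tiled = np.tile(y, (len(x), 1))
# Concatenate along the feature axis: shape (3 * 2, 2 + 3) = (6, 5)
joined = np.concatenate([x_rep, y_tiled], axis=1)
print(joined)
```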

And an example model (shown with dense layers, but in spirit this could be any layer or sequence of layers):

[Figure: diagram of the example model (image not available)]

Concatenation itself is simple enough per this other answer, and the inputs needn't have the same number of features, but I'm not sure whether or how one can build a model that concatenates every pair.
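To illustrate why pairing needs an extra step: ordinary concatenation along the feature axis tolerates different feature widths, but it is row-aligned, so the number of rows must match. A small NumPy sketch of that constraint, using the x and y from the example above (NumPy is used here only for illustration):

```python
import numpy as np

x = np.array([[1, 1],
              [2, 2],
              [3, 3]])
y = np.array([[4, 4, 4],
              [5, 5, 5]])

# Different widths (2 vs 3) are fine when the row counts match: shape (2, 5)
aligned = np.concatenate([x[:2], y], axis=1)
print(aligned.shape)

# 3 rows vs 2 rows: NumPy refuses, and row-wise concatenation never forms pairs anyway
try:
    np.concatenate([x, y], axis=1)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print(mismatch_raised)
```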

Is there a straightforward way to achieve this in Keras?

Answers (1)

Vlad

You can build every pair with tf.tile(), tf.reshape() and tf.concat(). The code uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session):

import tensorflow as tf
import numpy as np

x_data = np.array([[1, 1],
                   [2, 2],
                   [3, 3]], dtype=np.float32)
y_data = np.array([[4, 4, 4],
                   [5, 5, 5]], dtype=np.float32)

x = tf.placeholder(tf.float32, shape=(None, 2))
y = tf.placeholder(tf.float32, shape=(None, 3))
xshape = tf.shape(x)
yshape = tf.shape(y)
newshape = (xshape[0] * yshape[0], xshape[1] + yshape[1])

xres = tf.tile(x, multiples=[1, yshape[0]])
xres = tf.reshape(xres, [newshape[0], xshape[1]])
# `xres` is now: [[1. 1.]
#                 [1. 1.]
#                 [2. 2.]
#                 [2. 2.]
#                 [3. 3.]
#                 [3. 3.]]
yres = tf.tile(y, multiples=[xshape[0], 1])
# `yres` is now: [[4. 4. 4.]
#                 [5. 5. 5.]
#                 [4. 4. 4.]
#                 [5. 5. 5.]
#                 [4. 4. 4.]
#                 [5. 5. 5.]]
res = tf.concat([xres, yres], axis=1) # <-- this is your result
with tf.Session() as sess:
    evaled = sess.run(res, feed_dict={x: x_data, y: y_data})
    print(evaled)
# [[1. 1. 4. 4. 4.]
#  [1. 1. 5. 5. 5.]
#  [2. 2. 4. 4. 4.]
#  [2. 2. 5. 5. 5.]
#  [3. 3. 4. 4. 4.]
#  [3. 3. 5. 5. 5.]]
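The same tile/reshape/concat steps can be mirrored in NumPy to check the row ordering without a TensorFlow session. The sketch below is an illustration, not part of the answer's TensorFlow code; the helper name pairwise_concat is hypothetical. It also shows that row k of the result pairs x row k // n_y with y row k % n_y, which is how the entries of z line up with the pairs:

```python
import numpy as np

def pairwise_concat(x, y):
    """Hypothetical helper: NumPy mirror of the tf.tile / tf.reshape / tf.concat steps."""
    n_x, d_x = x.shape
    n_y = y.shape[0]
    # Like tf.tile(x, multiples=[1, n_y]): each row becomes n_y copies side by side
    xres = np.tile(x, (1, n_y))
    # Like tf.reshape(xres, [n_x * n_y, d_x]): each x row now appears n_y times consecutively
    xres = xres.reshape(n_x * n_y, d_x)
    # Like tf.tile(y, multiples=[n_x, 1]): the whole of y stacked n_x times
    yres = np.tile(y, (n_x, 1))
    return np.concatenate([xres, yres], axis=1)

x_data = np.array([[1, 1],
                   [2, 2],
                   [3, 3]], dtype=np.float32)
y_data = np.array([[4, 4, 4],
                   [5, 5, 5]], dtype=np.float32)

res = pairwise_concat(x_data, y_data)
print(res)

# Row k pairs x row k // n_y with y row k % n_y
x_idx, y_idx = np.divmod(np.arange(len(res)), len(y_data))
print(list(zip(x_idx.tolist(), y_idx.tolist())))
```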
