anthonya

Two inputs to one model in Keras

Is it possible in Keras to feed both an image and a vector of values as inputs to one model? If yes, how?

What I want is to create a CNN that takes an image and a vector of 6 values as inputs.

The output is a vector of 3 values.


Answers (1)

sdcbr

Yes. Have a look at the Keras Functional API guide, which has many examples of building models with multiple inputs.

Your code will look something like this, where you will probably want to pass the image through a convolutional layer, flatten the output and concatenate it with your vector input:

```python
from keras.layers import Input, Concatenate, Conv2D, Flatten, Dense
from keras.models import Model

# Define the two input layers
image_input = Input((32, 32, 3))
vector_input = Input((6,))

# Convolution + Flatten for the image
conv_layer = Conv2D(32, (3, 3))(image_input)
flat_layer = Flatten()(conv_layer)

# Concatenate the convolutional features and the vector input
concat_layer = Concatenate()([vector_input, flat_layer])
# Three output units with no activation (linear), one per target value
output = Dense(3)(concat_layer)

# Define a model with a list of two inputs
model = Model(inputs=[image_input, vector_input], outputs=output)
```

This gives you a model with the following structure, as printed by model.summary():

```
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_8 (InputLayer)            (None, 32, 32, 3)    0
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 30, 30, 32)   896         input_8[0][0]
__________________________________________________________________________________________________
input_9 (InputLayer)            (None, 6)            0
__________________________________________________________________________________________________
flatten_3 (Flatten)             (None, 28800)        0           conv2d_4[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 28806)        0           input_9[0][0]
                                                                 flatten_3[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 3)            86421       concatenate_3[0][0]
==================================================================================================
Total params: 87,317
Trainable params: 87,317
Non-trainable params: 0
```
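The shapes and parameter counts in the summary follow from Conv2D's defaults (stride 1, 'valid' padding) and the usual weights-plus-biases count. The plain-Python sketch below recomputes them by hand; the helper names conv2d_params and dense_params are hypothetical, not part of Keras.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One weight per kernel position, per input channel, per filter, plus one bias per filter
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
    # Full weight matrix plus one bias per output unit
    return in_features * units + units

img_h, img_w, img_c = 32, 32, 3   # image input shape
k, filters = 3, 32                # Conv2D(32, (3, 3))
vector_len, n_out = 6, 3          # vector input and Dense(3) output

# 'valid' padding with stride 1 shrinks each spatial dimension by k - 1
out_h, out_w = img_h - k + 1, img_w - k + 1   # 30, 30
flat_len = out_h * out_w * filters            # 28800
concat_len = flat_len + vector_len            # 28806

conv_p = conv2d_params(k, k, img_c, filters)  # 896
dense_p = dense_params(concat_len, n_out)     # 86421
total = conv_p + dense_p                      # 87317
print(out_h, out_w, flat_len, concat_len, conv_p, dense_p, total)
```

This also shows that the Dense layer holds 86,421 of the 87,317 parameters, because each of the 28,806 concatenated features connects to every one of the 3 output units.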

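Another way to visualize the model is with Keras' visualization utilities, for example keras.utils.plot_model(model, show_shapes=True), which draws the layer graph as an image (it requires pydot and Graphviz to be installed).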
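To actually feed the two inputs, pass them to fit and predict as a list in the same order as Model(inputs=[...]); the order used inside Concatenate does not matter here. A minimal sketch, assuming the Adam optimizer, mean squared error loss, and random placeholder arrays instead of real data (none of these choices come from the answer):

```python
import numpy as np
from keras.layers import Input, Concatenate, Conv2D, Flatten, Dense
from keras.models import Model

# Same architecture as above: image -> Conv2D -> Flatten, joined with the 6-value vector
image_input = Input((32, 32, 3))
vector_input = Input((6,))
flat_layer = Flatten()(Conv2D(32, (3, 3))(image_input))
concat_layer = Concatenate()([vector_input, flat_layer])
output = Dense(3)(concat_layer)
model = Model(inputs=[image_input, vector_input], outputs=output)

# Assumption: MSE suits a 3-value regression target
model.compile(optimizer="adam", loss="mse")

# Hypothetical placeholder data: 16 images, 16 vectors, 16 targets
rng = np.random.default_rng(0)
images = rng.random((16, 32, 32, 3), dtype=np.float32)
vectors = rng.random((16, 6), dtype=np.float32)
targets = rng.random((16, 3), dtype=np.float32)

# Inputs go in as a list, ordered like Model(inputs=[image_input, vector_input])
history = model.fit([images, vectors], targets, epochs=1, batch_size=8, verbose=0)

# Prediction takes the same two-element list and returns one 3-vector per sample
preds = model.predict([images[:2], vectors[:2]], verbose=0)
print(preds.shape)  # (2, 3)
```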
