dchurchwell

Reputation: 25

Concatenation of two Embedding layers in keras throws ValueError

I'm trying to build an NLP model that uses two different types of embedding, but I can't get the concatenation layer to function properly.

From what I can tell, the layer output shapes are correct, and they should be able to concatenate along the last axis. Am I calling something improperly?

Code sample:

from tensorflow.keras import layers

x_embed = layers.Embedding(90000, 100, input_length=9000)
feat_embed = layers.Embedding(90000, 8, input_length=9000)

layerlist = [x_embed, feat_embed]
concat = layers.Concatenate()(layerlist)

Error:

Exception has occurred: ValueError

A Concatenate layer should be called on a list of at least 2 inputs

Upvotes: 0

Views: 497

Answers (1)

Zabir Al Nazi Nabil

Reputation: 11198

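You need Input tensors in your model. `layers.Embedding(...)` only creates a layer object; nothing has been passed through it yet, so `x_embed` and `feat_embed` have no outputs for `Concatenate` to join, which is why it raises the ValueError. Call each Embedding layer on an Input tensor and concatenate the resulting output tensors along the last axis (`axis=-1`, which is also the `Concatenate` default).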
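Why the two outputs can be joined along the last axis can be sketched without Keras at all. This is a minimal NumPy sketch, assuming the usual mental model of an Embedding as a lookup into an `(input_dim, output_dim)` weight table (it ignores training entirely); the sizes are scaled down and names such as `word_table` and `feat_table` are hypothetical.

```python
import numpy as np

# Hypothetical, scaled-down sizes (the question uses 90000 / 9000).
vocab_size, seq_len, batch = 50, 7, 2
rng = np.random.default_rng(0)

# Lookup tables standing in for Embedding(vocab, 100) and Embedding(vocab, 8).
word_table = rng.normal(size=(vocab_size, 100))
feat_table = rng.normal(size=(vocab_size, 8))

# Integer ids per token: shape (batch, seq_len).
word_ids = rng.integers(0, vocab_size, size=(batch, seq_len))
feat_ids = rng.integers(0, vocab_size, size=(batch, seq_len))

# Indexing a table with the ids gives one vector per token.
word_vecs = word_table[word_ids]  # (2, 7, 100)
feat_vecs = feat_table[feat_ids]  # (2, 7, 8)

# Shapes agree on every axis except the last, so they join along axis -1.
combined = np.concatenate([word_vecs, feat_vecs], axis=-1)
print(combined.shape)  # (2, 7, 108)

# Each table holds input_dim * output_dim weights; at the question's sizes:
full_params = 90000 * 100 + 90000 * 8
print(full_params)  # 9720000
```

The last line reproduces the "Total params" figure in the summary, since neither Embedding nor Concatenate has any other weights.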

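from tensorflow.keras import layers
from tensorflow.keras import models

# Token-id sequences of length 9000 (the batch dimension is implicit)
ip1 = layers.Input((9000,))
ip2 = layers.Input((9000,))

# Call each Embedding layer on an Input so it returns a tensor
x_embed = layers.Embedding(90000, 100, input_length=9000)(ip1)   # (None, 9000, 100)
feat_embed = layers.Embedding(90000, 8, input_length=9000)(ip2)  # (None, 9000, 8)

# Concatenate the two output tensors along the feature axis -> (None, 9000, 108)
layerlist = [x_embed, feat_embed]
concat = layers.Concatenate(axis=-1)(layerlist)

model = models.Model([ip1, ip2], concat)
model.summary()

Output:

Model: "model_1"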
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 9000)]       0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 9000)]       0                                            
__________________________________________________________________________________________________
embedding_6 (Embedding)         (None, 9000, 100)    9000000     input_5[0][0]                    
__________________________________________________________________________________________________
embedding_7 (Embedding)         (None, 9000, 8)      720000      input_6[0][0]                    
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 9000, 108)    0           embedding_6[0][0]                
                                                                 embedding_7[0][0]                
==================================================================================================
Total params: 9,720,000
Trainable params: 9,720,000
Non-trainable params: 0
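To check that the model not only builds but also runs, here is a minimal sketch that pushes random token ids through the same two-input architecture. It assumes TensorFlow 2.x (`tensorflow.keras`) and uses hypothetical scaled-down sizes (vocabulary 1000, sequence length 20) so it runs quickly; `input_length` is left out because the Input shape already fixes the sequence length.

```python
import numpy as np
from tensorflow.keras import layers, models

# Hypothetical, scaled-down sizes; the question uses 90000 and 9000.
VOCAB, SEQ_LEN = 1000, 20

ip1 = layers.Input(shape=(SEQ_LEN,), dtype="int32")
ip2 = layers.Input(shape=(SEQ_LEN,), dtype="int32")
x_embed = layers.Embedding(VOCAB, 100)(ip1)   # (None, SEQ_LEN, 100)
feat_embed = layers.Embedding(VOCAB, 8)(ip2)  # (None, SEQ_LEN, 8)
concat = layers.Concatenate(axis=-1)([x_embed, feat_embed])  # (None, SEQ_LEN, 108)
model = models.Model([ip1, ip2], concat)

# A batch of 4 random token-id sequences for each input, ids in [0, VOCAB).
rng = np.random.default_rng(0)
batch_words = rng.integers(0, VOCAB, size=(4, SEQ_LEN), dtype=np.int32)
batch_feats = rng.integers(0, VOCAB, size=(4, SEQ_LEN), dtype=np.int32)

# The model takes a list of two arrays, one per Input, in the same order.
out = model.predict([batch_words, batch_feats], verbose=0)
print(out.shape)             # (4, 20, 108)
print(model.count_params())  # 108000 = 1000*100 + 1000*8
```

If both embeddings are meant to read the same token sequence rather than two separate ones, a single Input can be passed to both Embedding layers instead; the model then takes one input array rather than a list of two.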

Upvotes: 1
