fsociety

Reputation: 1870

Keras: weighted average of embedding layers

I am currently implementing a sequence model in Keras and want to utilize two (or more) pre-trained word embeddings. Currently, my way to go is to average both embedding matrices before passing it to Keras. However, I want to do weighted averaging instead.

Of course I can optimize the weights as a form of hyperparameter, but am wondering for a way to do it as part of the model, e.g., through a softmax dense layer for weighting. Ideally, I would have two options, the first fits the weights for merging the whole matrices, and the second has weights on a word level for merging the vectors. I have not figured out yet how to do it properly though and would be happy for suggestions.
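To make the two options concrete, here is a NumPy sketch of the current plain-average approach and the matrix-level weighted variant; the matrices, the logits, and the softmax normalization are illustrative assumptions, not part of any existing code:

```python
import numpy as np

# two hypothetical pre-trained embedding matrices of shape (vocab_size, embedding_dim)
vocab_size, embedding_dim = 1000, 50
rng = np.random.default_rng(0)
emb1 = rng.standard_normal((vocab_size, embedding_dim))
emb2 = rng.standard_normal((vocab_size, embedding_dim))

# current approach: plain average of the two matrices
avg = (emb1 + emb2) / 2

# option 1: one softmax-normalized scalar weight per matrix
# (in a model these logits would be trainable; here they are fixed)
logits = np.array([0.3, 0.7])
w = np.exp(logits) / np.exp(logits).sum()   # softmax, so the weights sum to 1
weighted = w[0] * emb1 + w[1] * emb2        # same shape as either matrix
```

Option 2 (word-level weights) would replace the two scalars with a weight vector of shape (vocab_size, 2), one softmax pair per word.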

Upvotes: 0

Views: 2509

Answers (1)

Daniel Möller

Reputation: 86620

For averaging, both embeddings must have the same output size.

You can simply stack them along the last dimension and pass the result to a Dense layer:

from keras.layers import Input, Lambda, Dense
from keras import backend as K

inputs = Input((length,))

# getEmbeddingFor / getEmbeddingFor2 stand for your two pre-trained embedding layers
embedding1 = getEmbeddingFor(inputs)
embedding2 = getEmbeddingFor2(inputs)

#stacks into shape (batch, length, embedding_size, 2)
stacked = Lambda(lambda x: K.stack([x[0], x[1]], axis=-1))([embedding1, embedding2])

#weights to (batch, length, embedding_size, 1)
weighted = Dense(1, use_bias=False)(stacked)

#removes the last dimension
weighted = Lambda(lambda x: K.squeeze(x, axis=-1))(weighted)
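To see why this is a weighted average: the Dense(1, use_bias=False) layer applies a single (2, 1) kernel as a dot product over the stacked axis, i.e. one trainable weight per embedding. A NumPy sketch of the same computation, with illustrative shapes and a fixed kernel in place of trained weights:

```python
import numpy as np

batch, length, emb_size = 2, 5, 8
rng = np.random.default_rng(1)
e1 = rng.standard_normal((batch, length, emb_size))
e2 = rng.standard_normal((batch, length, emb_size))

# K.stack(..., axis=-1): shape (batch, length, emb_size, 2)
stacked = np.stack([e1, e2], axis=-1)

# Dense(1, use_bias=False): one weight per stacked embedding, shape (2, 1)
kernel = np.array([[0.6], [0.4]])

weighted = stacked @ kernel                 # (batch, length, emb_size, 1)
weighted = np.squeeze(weighted, axis=-1)    # (batch, length, emb_size)

# this is exactly a weighted sum of the two embeddings
assert np.allclose(weighted, 0.6 * e1 + 0.4 * e2)
```

If you want the weights to sum to 1 (a true weighted average rather than a weighted sum), you could apply a softmax to the kernel, e.g. via a kernel constraint, though the answer's version leaves them unconstrained.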

Alternatively, if you don't mind having many weights instead of just two, if the embedding sizes differ, or if you want a more expressive weighting, you can use a simple concatenation followed by a Dense projection:

from keras.layers import Concatenate, Dense

weighted = Concatenate()([embedding1, embedding2])
weighted = Dense(similarToSize)(weighted)   # project back to the desired embedding size
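The concatenation variant can also be sketched in NumPy; the shapes and the random kernel are illustrative (in Keras the Dense kernel and bias would be trained), and it works even when the two embedding sizes differ:

```python
import numpy as np

batch, length, size1, size2, out_size = 2, 5, 8, 6, 8
rng = np.random.default_rng(2)
e1 = rng.standard_normal((batch, length, size1))
e2 = rng.standard_normal((batch, length, size2))

# Concatenate(): shape (batch, length, size1 + size2)
concat = np.concatenate([e1, e2], axis=-1)

# Dense(out_size): a learned linear projection to a common size
kernel = rng.standard_normal((size1 + size2, out_size))
bias = np.zeros(out_size)
projected = concat @ kernel + bias          # (batch, length, out_size)
```

Unlike the stacking approach, every output dimension here can mix all input dimensions of both embeddings, which is what the answer means by "a lot more intelligence in the weighting".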

The second approach might not sound like weighting, but if you think about it for a while, two embeddings may not assign the same meaning to the same positions, and averaging two values of different nature may not produce great results. (But, of course, neural networks are obscure, and only testing can prove this statement.)

Upvotes: 1
