Geralt Xu

Reputation: 41

How to Concatenate 2 Embedding layers with 'mask_zero=True' in Keras2.0?

I have two Embedding layers, one with mask_zero=True and the other with mask_zero=False, defined as below:

from keras.layers import Input, Embedding, Concatenate
from keras.models import Model

a = Input(shape=[30])
b = Input(shape=[30])
emb_a = Embedding(10, 5, mask_zero=True)(a)
emb_b = Embedding(20, 5, mask_zero=False)(b)
cat = Concatenate(axis=1)([emb_a, emb_b]) # problem here: raises ValueError
model = Model(inputs=[a, b], outputs=[cat])

When I tried to concatenate them along axis=1, I expected an output of shape [None, 60, 5], but it raised an error:

ValueError: Dimension 0 in both shapes must be equal, but are 1 and 5.
Shapes are [1] and [5]. for 'concatenate_1/concat_1' (op: 'ConcatV2') with input shapes: 
[?,30,1], [?,30,5], [] and with computed input tensors: input[2] = <1>.

Why does the shape of emb_a become [None, 30, 1]? And why is an extra empty tensor [] fed into Concatenate?

If the two Embedding layers both use mask_zero=True, or both use mask_zero=False, this error is not raised. Concatenating them along axis=2 does not raise it either.

My Keras version is 2.0.8.

Thank you.

Upvotes: 4

Views: 3198

Answers (1)

Vikash Singh

Reputation: 14001

Because one embedding has mask_zero=True and the other has mask_zero=False, something goes wrong internally when the masks are combined (which should not happen). It may be a bug, so you could report it on GitHub.
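
A likely explanation for both shapes in the error message, assuming Keras 2.0.x's Concatenate.compute_mask merges masks roughly as sketched below: the mask from mask_zero=True has shape (batch, 30) and is expanded to (batch, 30, 1), while the unmasked input gets an all-True mask shaped like the input itself, (batch, 30, 5). Those two can be joined along the last axis but not along axis=1. The empty shape [] in the message appears to be the scalar axis argument of TensorFlow's ConcatV2 op (input[2] = <1>), not a third tensor from the model. The code below only mimics that logic in NumPy; merge_masks is a hypothetical helper, not a Keras function.

```python
import numpy as np


def merge_masks(inputs, masks, axis):
    """Hypothetical NumPy stand-in for the mask merge assumed in Keras 2.0.x Concatenate.

    inputs: list of arrays shaped (batch, timesteps, dim)
    masks:  list of boolean arrays shaped (batch, timesteps), or None if unmasked
    """
    expanded = []
    for x, m in zip(inputs, masks):
        if m is None:
            # Unmasked input: an all-True mask with the *full* input shape.
            expanded.append(np.ones_like(x, dtype=bool))
        else:
            # Masked input: (batch, timesteps) -> (batch, timesteps, 1).
            expanded.append(m[..., np.newaxis])
    merged = np.concatenate(expanded, axis=axis)
    # Collapse the last axis so the merged mask is (batch, timesteps).
    return merged.all(axis=-1)


batch = 2
ids_a = np.random.randint(0, 10, size=(batch, 30))
emb_a = np.zeros((batch, 30, 5))   # stands in for Embedding(10, 5, mask_zero=True)(a)
emb_b = np.zeros((batch, 30, 5))   # stands in for Embedding(20, 5, mask_zero=False)(b)
mask_a = ids_a != 0                # the mask that mask_zero=True produces

try:
    merge_masks([emb_a, emb_b], [mask_a, None], axis=1)
except ValueError as err:
    # (batch, 30, 1) and (batch, 30, 5) differ outside the join axis.
    print("axis=1 fails:", err)
```

With mask=[mask_a, mask_a] (both True) the same helper joins two (batch, 30, 1) pieces along axis=1 without trouble, which matches the observation that matching mask_zero settings avoid the error.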

For now, one option that works is to use the same setting for both embeddings, either mask_zero=True or mask_zero=False (here both use the default, mask_zero=False):

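from keras.layers import Input, Embedding, Concatenate
from keras.models import Model

a = Input(shape=[30])
b = Input(shape=[30])
emb_a = Embedding(10, 5)(a)  # mask_zero defaults to False
emb_b = Embedding(20, 5)(b)
cat = Concatenate(axis=1)([emb_a, emb_b])
model = Model(inputs=[a, b], outputs=[cat])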

print(model.output_shape) # (None, 60, 5)
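
If you need masking, setting mask_zero=True on both embeddings should also work. A small NumPy sketch of what that implies: mask_zero=True marks every non-zero id as a real timestep, and joining two such masks along the time axis gives one mask entry per output timestep. The sequences here are shortened to 4 steps instead of 30 purely for illustration.

```python
import numpy as np

# Two padded integer batches, as fed to inputs a and b (0 = padding).
ids_a = np.array([[3, 7, 0, 0],
                  [1, 0, 0, 0]])
ids_b = np.array([[5, 0, 0, 0],
                  [2, 9, 4, 0]])

# mask_zero=True treats every non-zero id as a real timestep.
mask_a = ids_a != 0
mask_b = ids_b != 0

# Joining along the time axis (axis=1) gives a (batch, 4 + 4) mask,
# lining up with a (batch, 8, dim) concatenated output.
mask_cat = np.concatenate([mask_a, mask_b], axis=1)
print(mask_cat.astype(int))
# [[1 1 0 0 1 0 0 0]
#  [1 0 0 0 1 1 1 0]]
```

Note that the padding of a ends up in the middle of the combined sequence, which is worth keeping in mind for whatever mask-aware layer comes next.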

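Another approach is to concatenate along axis=-1 (the default for Concatenate). Note that this gives a different output shape, (None, 30, 10), instead of (None, 60, 5):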

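from keras.layers import Input, Embedding, Concatenate
from keras.models import Model

a = Input(shape=[30])
b = Input(shape=[30])
emb_a = Embedding(10, 5, mask_zero=True)(a)
emb_b = Embedding(20, 5, mask_zero=False)(b)
cat = Concatenate()([emb_a, emb_b]) # default axis=-1: joins the feature dimension, no error
model = Model(inputs=[a, b], outputs=[cat])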

print(model.output_shape) # (None, 30, 10)
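
Why this works, under the same assumption about how Keras 2.0.x merges masks: along the last axis, emb_a's expanded (batch, 30, 1) mask and emb_b's all-True (batch, 30, 5) mask fit together into (batch, 30, 6), and reducing with all over that axis gives back emb_a's own (batch, 30) mask. So timesteps where a is 0 should stay masked in the (None, 30, 10) output. A NumPy sketch of that reasoning (not Keras code):

```python
import numpy as np

batch, steps = 2, 30
ids_a = np.random.randint(0, 10, size=(batch, steps))
mask_a = ids_a != 0                        # mask from mask_zero=True
emb_b = np.zeros((batch, steps, 5))        # stand-in for emb_b's output

# Mask pieces before concatenation (assumed Keras 2.0.x behaviour).
piece_a = mask_a[..., np.newaxis]          # (batch, 30, 1)
piece_b = np.ones_like(emb_b, dtype=bool)  # (batch, 30, 5), input is unmasked

merged = np.concatenate([piece_a, piece_b], axis=-1)  # (batch, 30, 6)
out_mask = merged.all(axis=-1)                         # (batch, 30)

print(merged.shape, out_mask.shape)       # (2, 30, 6) (2, 30)
print(np.array_equal(out_mask, mask_a))   # True
```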

Upvotes: 2
