BlueMango

Reputation: 503

Keras Lambda layer gives ValueError: Input 0 is incompatible with layer xxx: expected min_ndim=3, found ndim=2

When I add a Lambda layer to my sequential model, it gives ValueError: Input 0 is incompatible with ....

For this model I get ValueError: Input 0 is incompatible with layer flatten_1: expected min_ndim=3, found ndim=2

model1 = Sequential()
model1.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model1.add(Lambda(lambda x: mean(x, axis=1)))
model1.add(Flatten())
model1.add(Bidirectional(LSTM(32)))
model1.add(Dropout(0.6))
model1.add(Dense(2))

If I remove the Flatten() I get: ValueError: Input 0 is incompatible with layer bidirectional_1: expected ndim=3, found ndim=2. However, without the Lambda layer the model works fine.
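
A quick way to see where the rank drops is to build just the first two layers and print the output shape (here K.mean without keepdims stands in for the mean helper, which isn't shown above):

from keras.models import Sequential
from keras.layers import Embedding, Lambda
from keras import backend as K

max_words, embedding_dim, maxlen = 1000, 300, 10    # illustrative values

probe = Sequential()
probe.add(Embedding(max_words, embedding_dim, input_length=maxlen))
probe.add(Lambda(lambda x: K.mean(x, axis=1)))      # mean over the time axis, no keepdims

# Embedding outputs (None, maxlen, embedding_dim) = (None, 10, 300);
# averaging over axis=1 drops that axis, so the Lambda output is (None, 300),
# i.e. only 2 dimensions -- which is why Flatten (min_ndim=3) and the
# Bidirectional LSTM (ndim=3) both reject it.
print(probe.output_shape)   # (None, 300)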

Any idea on what is causing this problem and how I can solve it would be appreciated. Thanks

Upvotes: 0

Views: 729

Answers (1)

Pedro Marques

Reputation: 2682

The following generates a graph that seems correct:

from keras.models import Sequential
from keras.layers import *
from keras import backend as K   # needed for K.mean below

max_words = 1000
embedding_dim = 300
maxlen = 10

def mean(x, axis):
  """mean
     input_shape=(batch_size, time_slots, ndims)
     depending on the axis mean will:
       0: compute mean value for a batch and reduce batch size to 1
       1: compute mean value across time slots and reduce time_slots to 1
       2: compute mean value across ndims and reduce ndims to 1.
  """
  return K.mean(x, axis=axis, keepdims=True)

model1 = Sequential()
model1.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model1.add(Lambda(lambda x: mean(x, axis=1)))
model1.add(Bidirectional(LSTM(32)))
model1.add(Dropout(0.6))
model1.add(Dense(2))
model1.compile('sgd', 'mse')
model1.summary()

The Embedding layer outputs 3 dimensions (batch_size, maxlen, embedding_dim). The LSTM layer also expects a 3-dimensional input, so the Lambda should return a compatible shape, or you need to reshape its output. Here K.mean offers a convenient parameter (keepdims) which helps us do that.
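
As a quick illustration of what keepdims changes (a small sketch alongside the answer above), compare the static shapes the backend reports with and without it:

from keras import backend as K

x = K.placeholder(shape=(None, 10, 300))                # (batch_size, maxlen, embedding_dim)
print(K.int_shape(K.mean(x, axis=1)))                   # (None, 300)    -> ndim=2
print(K.int_shape(K.mean(x, axis=1, keepdims=True)))    # (None, 1, 300) -> ndim=3

With keepdims=True the time axis is kept (with length 1), so the Lambda output is still 3-dimensional and the Bidirectional LSTM accepts it.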

Upvotes: 1
