hiteshn97

Reputation: 90

Attention on top of LSTM Keras

I was training an LSTM model using Keras and wanted to add attention on top of it. I am new to Keras and to attention mechanisms. From the linked question How to add an attention mechanism in keras? I learnt how I could add attention over my LSTM layer, and built a model like this:

print('Defining a Simple Keras Model...')
lstm_model=Sequential()  # or Graph 
lstm_model.add(Embedding(output_dim=300,input_dim=n_symbols,mask_zero=True,
                    weights=[embedding_weights],input_length=input_length))  

# Adding Input Length
lstm_model.add(Bidirectional(LSTM(300)))
lstm_model.add(Dropout(0.3))
lstm_model.add(Dense(1,activation='sigmoid'))

# compute importance for each step
attention=Dense(1, activation='tanh')
attention=Flatten()
attention=Activation('softmax')
attention=RepeatVector(64)
attention=Permute([2, 1])


sent_representation=keras.layers.Add()([lstm_model,attention])
sent_representation=Lambda(lambda xin: K.sum(xin, axis=-2),output_shape=(64))(sent_representation)

sent_representation.add(Dense(1,activation='sigmoid'))

rms_prop=RMSprop(lr=0.001,rho=0.9,epsilon=None,decay=0.0)
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
print('Compiling the Model...')
sent_representation.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
          #class_mode='binary')

earlyStopping=EarlyStopping(monitor='val_loss',min_delta=0,patience=0,
                                    verbose=0,mode='auto')

print("Train...")
sent_representation.fit(X_train, y_train,batch_size=batch_size,nb_epoch=20,
          validation_data=(X_test,y_test),callbacks=[earlyStopping])

The output should be a binary sentiment label (0/1). For that I added

 sent_representation.add(Dense(1,activation='sigmoid'))

so that it gives a binary result.

This is the error we get when running the code:

ERROR:
  File "<ipython-input-6-50a1a221497d>", line 18, in <module>
    sent_representation=keras.layers.Add()([lstm_model,attention])

  File "C:\Users\DuttaHritwik\Anaconda3\lib\site-packages\keras\engine\topology.py", line 575, in __call__
    self.assert_input_compatibility(inputs)

  File "C:\Users\DuttaHritwik\Anaconda3\lib\site-packages\keras\engine\topology.py", line 448, in assert_input_compatibility
    str(inputs) + '. All inputs to the layer '

ValueError: Layer add_1 was called with an input that isn't a symbolic tensor. Received type: <class 'keras.models.Sequential'>. Full input: [<keras.models.Sequential object at 0x00000220B565ED30>, <keras.layers.core.Permute object at 0x00000220FE853978>]. All inputs to the layer should be tensors.

Can you have a look and tell us what we are doing wrong here?

Upvotes: 2

Views: 3950

Answers (1)

platinum95

Reputation: 408

keras.layers.Add() takes tensors, so at

sent_representation=keras.layers.Add()([lstm_model,attention])

you're passing a Sequential model as input, which is why you're getting the error. Change your initial layers from using the Sequential model to using the functional API:

# 'input' here should be a Keras Input tensor, e.g. input = Input(shape=(input_length,), dtype='int32')
lstm_section = Embedding(output_dim=300,input_dim=n_symbols,mask_zero=True, weights=[embedding_weights],input_length=input_length)( input )
lstm_section = Bidirectional(LSTM(300)) ( lstm_section )
lstm_section = Dropout(0.3)( lstm_section )
lstm_section = Dense(1,activation='sigmoid')( lstm_section )

lstm_section is a tensor that can then replace lstm_model in your Add() call.

Since you're using the functional API rather than Sequential, you'll also need to create the model, using your_model = keras.models.Model( inputs, sent_representation )
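
For example, a minimal sketch, assuming input is the Input tensor from the snippet above and sent_representation is your final output tensor (in Keras 2 the fit argument is epochs rather than nb_epoch):

model = keras.models.Model(inputs=input, outputs=sent_representation)
model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, epochs=20,
          validation_data=(X_test, y_test), callbacks=[earlyStopping])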

Also worth noting that the attention model in the link you gave multiplies rather than adds, so it might be worth using keras.layers.Multiply().

Edit

Just noticed that your attention section also isn't building a graph since you're not passing each layer into the next one. It should be:

attention=Dense(1, activation='tanh')( lstm_section )
attention=Flatten()( attention )
attention=Activation('softmax')( attention )
attention=RepeatVector(64)( attention )
attention=Permute([2, 1])( attention )
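
Putting the pieces together, here is a rough, untested sketch of how the whole thing could look with the functional API, using Multiply() as mentioned above. Note that attention over timesteps only makes sense if the LSTM returns the full sequence (return_sequences=True), and the RepeatVector size has to match the LSTM output dimension (600 here, since Bidirectional doubles the 300 units). The names n_symbols, embedding_weights and input_length are taken from your code; the rest is an assumption about what you're aiming for:

from keras.layers import (Input, Embedding, Bidirectional, LSTM, Dropout, Dense,
                          Flatten, Activation, RepeatVector, Permute, Multiply, Lambda)
from keras.models import Model
from keras import backend as K

inputs = Input(shape=(input_length,), dtype='int32')

# Embedding + BiLSTM; return_sequences=True so attention can weight every timestep.
# mask_zero is left off here because the Flatten layer below does not support masking.
x = Embedding(output_dim=300, input_dim=n_symbols,
              weights=[embedding_weights], input_length=input_length)(inputs)
lstm_out = Bidirectional(LSTM(300, return_sequences=True))(x)   # (batch, timesteps, 600)
lstm_out = Dropout(0.3)(lstm_out)

# Attention: one score per timestep, softmax-normalised, then broadcast back
attention = Dense(1, activation='tanh')(lstm_out)   # (batch, timesteps, 1)
attention = Flatten()(attention)                     # (batch, timesteps)
attention = Activation('softmax')(attention)         # attention weights over timesteps
attention = RepeatVector(600)(attention)             # (batch, 600, timesteps)
attention = Permute([2, 1])(attention)               # (batch, timesteps, 600)

# Weight the LSTM outputs by the attention weights and sum over time
sent_representation = Multiply()([lstm_out, attention])
sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2),
                             output_shape=(600,))(sent_representation)

output = Dense(1, activation='sigmoid')(sent_representation)

model = Model(inputs=inputs, outputs=output)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])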

Upvotes: 2
