SecretIndividual

Reputation: 2519

Keras accuracy not increasing

I am trying to perform sentiment classification using Keras with a basic neural network (no RNN or other more complex type). However, when I run the script, I see no increase in accuracy during training/evaluation. I suspect I am setting up the output layer incorrectly, but I am not sure. y_train is a list [1,2,3,1,2,4,5] (5 different labels) containing the targets for the features in X_train_seq_padded. The setup is as follows:

padding_len = 24 # len of each tokenized sentence
neurons = 16 # 2/3 the length of the text that is padded
model = Sequential()
model.add(Dense(neurons, input_dim = padding_len, activation = 'relu', name = 'hidden-1'))
model.add(Dense(neurons, activation = 'relu', name = 'hidden-2'))
model.add(Dense(neurons, activation = 'relu', name = 'hidden-3'))
model.add(Dense(1, activation = 'sigmoid', name = 'output_layer'))

model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics=['accuracy'])

callbacks = [EarlyStopping(monitor = 'accuracy', patience = 5, mode = 'max')]
history = model.fit(X_train_seq_padded, y_train, epochs = 100, batch_size = 64, callbacks = callbacks)

Upvotes: 1

Views: 287

Answers (1)

Innat

Reputation: 17219

First of all, in your setup above, if you use sigmoid as the activation function of the last layer, which is generally used for binary or multi-label classification, then the loss function should be binary_crossentropy.

But if your labels are multi-class and one-hot encoded, then your last layer should be Dense(num_classes, activation='softmax') and the loss function should be categorical_crossentropy.

But if you keep your multi-class labels as integers instead of one-hot encoding them, then your last layer and loss function should be

Dense(num_classes)  # with logits
SparseCategoricalCrossentropy(from_logits=True)

Or, (@Frightera)

Dense(num_classes, activation='softmax') # with probabilities 
SparseCategoricalCrossentropy(from_logits=False)
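
To illustrate why the integer-label and one-hot options are interchangeable, here is a minimal pure-Python sketch (not Keras code) of what the two losses compute for a single sample. The logits are made-up values, and the helper functions are hypothetical stand-ins for the library's losses. One thing to watch with the labels from the question: the sparse loss expects class indices in the range 0..num_classes-1, so labels 1..5 are shifted to 0..4 first.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(index, num_classes):
    # Turn a class index into a one-hot vector.
    return [1.0 if i == index else 0.0 for i in range(num_classes)]

def categorical_crossentropy(one_hot_target, probs):
    # Loss for one-hot targets: -sum(t_i * log(p_i)).
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs))

def sparse_categorical_crossentropy(class_index, probs):
    # Loss for integer targets: -log(p[class_index]).
    return -math.log(probs[class_index])

num_classes = 5
y_train = [1, 2, 3, 1, 2, 4, 5]                  # labels from the question (start at 1)
y_zero_based = [label - 1 for label in y_train]  # shifted to 0..4

# Hypothetical logits for one sample, standing in for a Dense(num_classes) output.
logits = [0.2, 1.5, -0.3, 0.0, 0.7]
probs = softmax(logits)

target = y_zero_based[1]  # second sample: original label 2 -> index 1
loss_sparse = sparse_categorical_crossentropy(target, probs)
loss_onehot = categorical_crossentropy(one_hot(target, num_classes), probs)

print(loss_sparse, loss_onehot)  # the two values agree
```

Both losses need one output per class, which is why a single sigmoid unit cannot work here: Dense(1) produces one number, while the loss needs a probability for each of the 5 classes.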

Upvotes: 1
