Reputation: 9
I'm building a convolutional neural network to predict 5 emotions from a data set of faces. After working on the weights, I got an accuracy of about 75%:
score = model_2_emotion.evaluate(test_datagen.flow(X_test, Y_test, batch_size = 4))
print('Accuracy: {}'.format(score[1]))
308/308 [==============================] - 17s 56ms/step - loss: 0.6139 - accuracy: 0.7575
Accuracy: 0.7575264573097229
But model_2_emotion.predict(X_test) returns this array:
array([[0.6594997 , 0.00083318, 0.19473663, 0.08065161, 0.06427888],
[0.6610887 , 0.0008383 , 0.19332188, 0.08035047, 0.06440066],
[0.66172844, 0.00082645, 0.19264877, 0.08032911, 0.06446711],
...,
[0.66067713, 0.00084266, 0.19318439, 0.08052441, 0.06477145],
[0.66050553, 0.00085838, 0.19319515, 0.08056776, 0.06487323],
[0.6602842 , 0.00084602, 0.19372217, 0.08054546, 0.06460217]],
dtype=float32)
Here we can see that it is only ever predicting the first emotion (the first column), with a score of roughly 0.66, and from this array I get this heat map:
I think something is wrong, since every prediction goes to the first emotion. I got 75% accuracy but bad predictions; does anyone know what's going on?
Upvotes: 0
Views: 553
Reputation: 1929
Looking at your confusion matrix (it is not called a heat map), it seems that your model is predicting only a single class and that your data is unbalanced.
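To make the "single class" symptom concrete, here is a minimal sketch of how a confusion matrix can be built from one-hot labels and predicted probabilities. The helper names (`argmax`, `confusion_matrix`) and the toy labels are made up for illustration; only the probability rows are copied from the output in the question. It uses plain Python so it runs without extra packages, and the helpers also accept NumPy arrays.

```python
def argmax(row):
    """Index of the largest value in a sequence."""
    return max(range(len(row)), key=lambda i: row[i])


def confusion_matrix(y_true_onehot, y_prob, n_classes):
    """Count samples per (true class, predicted class) pair.

    Rows are true classes, columns are predicted classes.
    """
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for true_row, prob_row in zip(y_true_onehot, y_prob):
        matrix[argmax(true_row)][argmax(prob_row)] += 1
    return matrix


# Hypothetical one-hot labels for four samples (NOT the asker's real data).
y_true = [
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
]
# Probability rows copied from the predict() output in the question.
y_prob = [
    [0.6594997, 0.00083318, 0.19473663, 0.08065161, 0.06427888],
    [0.6610887, 0.0008383, 0.19332188, 0.08035047, 0.06440066],
    [0.66172844, 0.00082645, 0.19264877, 0.08032911, 0.06446711],
    [0.66067713, 0.00084266, 0.19318439, 0.08052441, 0.06477145],
]

cm = confusion_matrix(y_true, y_prob, n_classes=5)
for row in cm:
    print(row)
```

If every count lands in the first column, as it does for these rows, the model is assigning all samples to class 0 regardless of their true label.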
How many samples do you have for each class (is it unbalanced)?
For how many epochs is your model trained?
How many neurons does your neural network have in the last layer (it is supposed to have 5)?
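The first and last questions can be checked quickly. Below is a rough sketch, assuming the labels are one-hot encoded like `Y_test` in the question; `class_counts`, `majority_share` and the toy labels are hypothetical names and data used only for illustration.

```python
from collections import Counter


def class_counts(y_onehot):
    """Number of samples per class index, given one-hot encoded labels."""
    return Counter(
        max(range(len(row)), key=lambda i: row[i]) for row in y_onehot
    )


def majority_share(counts):
    """Fraction of all samples that belong to the most frequent class."""
    return max(counts.values()) / sum(counts.values())


# Hypothetical labels: three samples of class 0, one of class 2.
toy_labels = [
    [1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [1, 0, 0, 0, 0],
]

counts = class_counts(toy_labels)
share = majority_share(counts)
print(dict(counts), share)

# With the real Keras model, model_2_emotion.output_shape can be inspected
# as well; for 5 emotions its last dimension should be 5.
```

If the majority share on the real test set is close to the reported accuracy, a model that always answers with the most frequent class would score about the same, which is a strong hint that the data is unbalanced.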
A better suggestion could only be made by looking more closely at the data/problem (and at the train/test accuracy curves over epochs), but your problem seems to be underfitting or overfitting, and you could benefit from a better theoretical basis.
Take a look at any source on the bias-variance trade-off.
https://quantdare.com/mitigating-overfitting-neural-networks/
Here are some generic tips: get more data, improve preprocessing, improve the model (more layers, different kernel sizes, skip connections, batch normalization, different optimizers/learning rates, etc.).
Upvotes: 1