Hussain Ali

Reputation: 183

Why not use mean squared error for classification problems?

I am trying to solve a simple binary classification problem using an LSTM, and I am trying to figure out the correct loss function for the network. The issue is that when I use binary cross-entropy as the loss function, the loss values for training and testing are relatively high compared to using the mean squared error (MSE) function.

Upon researching, I came across justifications that binary cross-entropy should be used for classification problems and MSE for regression problems. However, in my case, I am getting better accuracy and a lower loss value with MSE for binary classification.

I am not sure how to justify these results. Why not use mean squared error for classification problems?

Upvotes: 18

Views: 16150

Answers (5)

SharmaTu

Reputation: 1

The answer lies in how the problem is designed. A regression objective is designed to predict a value on a continuous scale, fitting a curve through the data by minimizing $\sum_i (f(x_i) - y_i)^2$, which derives from the assumed normal distribution of the target and error, along with the other regression assumptions.

The classification objective derives from a Bernoulli assumption on the target variable. It is designed to fit a separating plane between the classes rather than a curve through the data, and it is obtained by multiplying together the probabilities of success ($y = 1$) and failure ($y = 0$) over the samples, where the success probability is computed with the sigmoid function.

If we instead use the mean squared error loss on top of the sigmoid output, the optimization problem becomes non-convex. This does not guarantee a global minimum and admits many spurious solutions. See the NIPS paper "Exponentially many local minima for single neurons": https://papers.nips.cc/paper_files/paper/1995/hash/3806734b256c27e41ec2c6bffa26d9e7-Abstract.html
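Here is a minimal NumPy sketch (my own toy example, not from the paper) that illustrates this: for a single sample with label $y = 1$ and prediction $\sigma(w)$, the second differences of the MSE loss along $w$ go negative (non-convex), while those of the cross-entropy loss stay non-negative (convex).

```python
import numpy as np

# Toy setup: one sample with label y = 1 and prediction sigmoid(w).
# A convex function has non-negative second differences everywhere,
# so we scan w on a grid and check the sign numerically.

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

w = np.linspace(-8.0, 8.0, 2001)
p = sigmoid(w)                       # predicted probability of the positive class

mse = (p - 1.0) ** 2                 # squared error against label y = 1
bce = -np.log(p)                     # binary cross-entropy against label y = 1

print("min 2nd diff, MSE:", np.diff(mse, n=2).min())  # negative -> non-convex in w
print("min 2nd diff, BCE:", np.diff(bce, n=2).min())  # non-negative -> convex in w
```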

Upvotes: 0

user2301346

Reputation: 480

The answer is right there in your question: the binary cross-entropy loss takes larger values than the MSE loss for the same predictions.

Case 1 (Large Error):

Let's say your model predicted 1e-7 and the actual label is 1.

The binary cross-entropy loss will be -log(1e-7) ≈ 16.12.

The mean squared error will be (1 - 1e-7)^2 ≈ 1.0.

Case 2 (Small Error):

Let's say your model predicted 0.94 and the actual label is 1.

The binary cross-entropy loss will be -log(0.94) ≈ 0.06.

The mean squared error will be (1 - 0.94)^2 = 0.0036.

In Case 1, when the prediction is far off from reality, the BCE loss has a much larger value than the MSE. A large loss value produces large gradients, so the optimizer takes a bigger step in the direction opposite to the gradient, which results in a relatively larger reduction in loss.
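The two cases in code, as a small NumPy sketch that also prints the gradient of each loss with respect to the prediction p (for y = 1, dBCE/dp = -1/p and dMSE/dp = -2(1-p)):

```python
import numpy as np

# The two cases above: label y = 1, predictions p = 1e-7 and p = 0.94.
# Compare loss values and gradient magnitudes w.r.t. the prediction p.
y = 1.0
for p in (1e-7, 0.94):
    bce = -np.log(p)              # binary cross-entropy for y = 1
    mse = (y - p) ** 2            # squared error for a single prediction
    grad_bce = -1.0 / p           # d(BCE)/dp
    grad_mse = -2.0 * (y - p)     # d(MSE)/dp
    print(f"p = {p:g}: BCE = {bce:.4f} (grad {grad_bce:.2e}), "
          f"MSE = {mse:.4f} (grad {grad_mse:.2e})")
```

For the confident-but-wrong prediction, the BCE gradient is on the order of 1e7 while the MSE gradient is only about 2, which is exactly the effect described above.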

Upvotes: 7

ChrisZZ
ChrisZZ

Reputation: 2141

Though @nerd21 gives a good example of why MSE is a bad loss function for 6-class classification, the same argument does not carry over to binary classification.

Let's consider just binary classification. The label is [1, 0]; one prediction is h1 = [p, 1-p] and another is h2 = [q, 1-q], so their (summed) squared errors are:

L1 = 2(1-p)^2, L2 = 2(1-q)^2

Assume h1 is a misclassification, i.e. p < 1-p, so 0 < p < 0.5. Assume h2 is a correct classification, i.e. q > 1-q, so 0.5 < q < 1. Then L1 - L2 = 2(p-q)(p+q-2) > 0 is guaranteed: p < q, so p - q < 0; and p + q < 0.5 + 1 = 1.5, so p + q - 2 < -0.5 < 0; the product of two negative factors is positive, hence L1 > L2.

This means that for binary classification with MSE as the loss function, a misclassification always incurs a larger loss than a correct classification.
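A quick numeric spot-check of this inequality (my own sketch):

```python
import numpy as np

# Sample many misclassified p in (0, 0.5) and correctly classified
# q in (0.5, 1), and verify that L1 > L2 holds in every case.
rng = np.random.default_rng(0)
p = rng.uniform(0.0, 0.5, 100_000)   # misclassified: p < 1 - p
q = rng.uniform(0.5, 1.0, 100_000)   # correct:       q > 1 - q

L1 = 2 * (1 - p) ** 2                # squared error of [p, 1-p] vs [1, 0]
L2 = 2 * (1 - q) ** 2                # squared error of [q, 1-q] vs [1, 0]

print("L1 > L2 for all samples:", bool(np.all(L1 > L2)))   # True
```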

Upvotes: 4

vipin bansal

Reputation: 896

I'd like to share my understanding of the MSE and binary cross-entropy functions.

In the case of classification, we take the argmax of the predicted probabilities for each instance.

Now, consider an example of a binary classifier where the model predicts the probabilities [0.49, 0.51]. In this case, the model will return 1 as the prediction.

Now, assume that the actual label is also 1.

In such a case, if the squared error is computed on the argmax'ed prediction (1 vs. 1), it returns 0 as the loss value, whereas binary cross-entropy, computed on the probability 0.51, returns some "tangible" value. And if the trained model produces similarly unconfident probabilities on all data samples, binary cross-entropy accumulates a large total loss, whereas the MSE on the hard predictions stays at 0.

According to that MSE, it is a perfect model, but actually it is not that good a model; that is why we should not use MSE for classification.
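A minimal sketch of this point, under the assumption stated above that the squared error is scored on the hard argmax output rather than on the probabilities:

```python
import numpy as np

# Score the hard argmax with squared error vs. score the probability
# with cross-entropy, for prediction [0.49, 0.51] and true label 1.
probs = np.array([0.49, 0.51])
label = 1

hard_pred = int(np.argmax(probs))            # -> 1
se_on_argmax = (label - hard_pred) ** 2      # 0: looks like a "perfect" model
bce_on_probs = -np.log(probs[label])         # ~0.673: reflects the low confidence

print(se_on_argmax, bce_on_probs)
```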

Upvotes: -3

nerd21

Reputation: 153

I would like to show this using an example. Assume a 6-class classification problem.

Assume, True probabilities = [1, 0, 0, 0, 0, 0]

Case 1: Predicted probabilities = [0.2, 0.16, 0.16, 0.16, 0.16, 0.16]

Case 2: Predicted probabilities = [0.4, 0.5, 0.1, 0, 0, 0]

The MSE in Case 1 and Case 2 is 0.128 and 0.1033, respectively.

Although Case 1 correctly predicts class 1 for the instance (the true class receives the highest probability) while Case 2 does not, the loss in Case 1 is higher than the loss in Case 2.
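A quick NumPy check of these numbers:

```python
import numpy as np

# Verify the MSE values quoted above.
y     = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
case1 = np.array([0.2, 0.16, 0.16, 0.16, 0.16, 0.16])
case2 = np.array([0.4, 0.5, 0.1, 0.0, 0.0, 0.0])

print(np.mean((y - case1) ** 2))   # 0.128  -- correct argmax, higher MSE
print(np.mean((y - case2) ** 2))   # 0.1033 -- wrong argmax, lower MSE
```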

Upvotes: 8
