I fit a multinomial logistic regression model with glmnet, and I'd like to get a confusion matrix so I can compute the accuracy:
```r
library("glmnet")
x = data.matrix(train[-1])
y = data.matrix(train[1])   # data.matrix() turns the factor response into its integer codes
x_test = data.matrix(test[-1])
y_test = unlist(test[1])
fit.glm = glmnet(x, y, family = "multinomial", alpha = 1, type.multinomial = "grouped")
cvfit = cv.glmnet(x, y, family = "multinomial", type.multinomial = "grouped", parallel = TRUE)
y_predict = unlist(predict(cvfit, newx = x_test, s = "lambda.min", type = "class"))
```
Then, to calculate the confusion matrix, I use the caret package:

```r
library("lattice")
library("ggplot2")
library("caret")
confusionMatrix(data = y_predict, reference = y_test)
```
I get this error and don't know how to solve it:

```
Error in confusionMatrix.default(data = y_predict, reference = y_test) : The data must contain some levels that overlap the reference.
```
Here is the `str()` of `y_predict` and `y_test`, in case it helps:

```r
str(y_predict)
 chr [1:301, 1] "6" "2" "7" "9" "3" "2" "3" "6" "6" "8" "6" "5" "6" ...
 - attr(*, "dimnames")=List of 2
  ..$ : NULL
  ..$ : chr "1"
str(y_test)
 Factor w/ 10 levels "accessory","activation",..: 6 8 2 9 3 2 3 5 10 8 ...
 - attr(*, "names")= chr [1:301] "category1" "category2" "category3" "category4" ...
```
I use `unlist()` to avoid getting this error: `Error: x must be atomic for 'sort.list'`
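The two `str()` outputs point at the cause: `y_predict` holds code strings such as `"6"`, while `y_test` is a factor whose levels are words such as `"accessory"`, so there is nothing in common for caret to match. The base-R sketch below reproduces this with small hypothetical vectors (`y_test_toy`, `y_predict_toy` and `pred_labels_toy` are made-up names and values; only their shapes mirror the question) and maps the codes back to labels. The mapping assumes that code `k` means `levels(y_test)[k]`, i.e. that the training and test factors have the same levels in the same order.

```r
# Hypothetical stand-ins shaped like the question's objects:
# the reference is a labelled factor, the predictions are class codes stored as strings.
y_test_toy <- factor(c("activation", "accessory", "activation"),
                     levels = c("accessory", "activation"))
y_predict_toy <- c("2", "1", "1")

# The two share no values, which is what confusionMatrix() rejects.
intersect(y_predict_toy, levels(y_test_toy))  # character(0)

# Assuming code k stands for levels(y_test_toy)[k], map the codes back to labels
# and keep the full level set so the table has a row for every class.
pred_labels_toy <- factor(levels(y_test_toy)[as.integer(y_predict_toy)],
                          levels = levels(y_test_toy))
table(pred_labels_toy, y_test_toy)
```

With the real objects the same idea would be `factor(levels(y_test)[as.integer(y_predict)], levels = levels(y_test))`, passed to `caret::confusionMatrix()` together with `y_test`. Alternatively, giving `glmnet` the factor itself (for example `train[[1]]`) as `y` should make `predict(..., type = "class")` return the labels directly.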
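It makes sense to keep track of your labels: your predictions come back as the integer codes that `data.matrix()` assigned to the factor, while `y_test` holds the text labels, so caret finds no overlapping levels. Use the labels to convert the results from glmnet back, then build the confusion matrix. I use the iris dataset, which has 3 labels: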
```r
idx = sample(nrow(iris), 100)
train = iris[idx,]
test = iris[-idx,]
```
We convert the response into numeric class codes (0, 1, 2):

```r
x = data.matrix(train[,-5])
y = as.numeric(train[,5]) - 1
x_test = data.matrix(test[,-5])
y_test = as.numeric(test[,5]) - 1
```
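Because `as.numeric()` on a factor returns each value's level position, the shift by 1 is easy to undo with `levels()`. A small sketch on the built-in `iris` data; `f`, `codes` and `labels_back` are illustrative names only:

```r
# Species is a factor with 3 levels: setosa, versicolor, virginica.
f <- iris$Species

# 0-based class codes, as used for y above: 0 = setosa, 1 = versicolor, 2 = virginica.
codes <- as.numeric(f) - 1

# Shift back to 1-based positions and look the labels up again.
labels_back <- levels(f)[codes + 1]
identical(labels_back, as.character(f))  # TRUE
```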
Then fit the model. One difference here: we ask for the class probabilities (`type = "response"`) instead of the classes:

```r
cvfit = cv.glmnet(x, y, family = "multinomial")
y_predict = predict(cvfit, newx = x_test, s = "lambda.min", type = "response")
```
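With `type = "response"` and a single value of `s`, the multinomial prediction is an array of observations by classes by 1, so `y_predict[,,1]` is a plain matrix and `max.col()` returns the column with the highest probability in each row. A toy sketch with made-up probabilities stored in that layout (`probs`, `best` and `class_labels` are hypothetical names):

```r
# Made-up probabilities for 3 observations and 3 classes (columns "0", "1", "2"),
# in the same observations x classes x lambda layout; filled column by column.
probs <- array(c(0.8, 0.1, 0.2,
                 0.1, 0.7, 0.3,
                 0.1, 0.2, 0.5),
               dim = c(3, 3, 1),
               dimnames = list(NULL, c("0", "1", "2"), "1"))

# Most probable column per row. max.col() breaks ties at random by default;
# ties.method = "first" makes the result deterministic.
best <- max.col(probs[, , 1], ties.method = "first")

# Column k corresponds to the k-th label because the codes were level position - 1.
class_labels <- c("setosa", "versicolor", "virginica")
class_labels[best]
```

The answer's own line keeps the default tie handling; with continuous probabilities exact ties are rare, but `ties.method = "first"` avoids any run-to-run difference.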
In this example the response is the column `Species`; in yours it will be `test[,1]`:

```r
ref_labels = test$Species
pred_labels = levels(test$Species)[max.col(y_predict[,,1])]
caret::confusionMatrix(table(pred_labels, ref_labels))
```
```
Confusion Matrix and Statistics

            ref_labels
pred_labels  setosa versicolor virginica
  setosa         20          0         0
  versicolor      0         12         0
  virginica       0          0        18
```
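Since the question's goal was the accuracy: it is the share of observations on the diagonal of the confusion table. A base-R sketch that rebuilds the table printed above (`tab` and `accuracy` are illustrative names) and computes it; caret reports the same quantity as `Accuracy` in the `overall` element of the object returned by `confusionMatrix()`.

```r
# Rebuild the confusion table shown above (rows = predictions, columns = reference).
lbl <- c("setosa", "versicolor", "virginica")
tab <- as.table(matrix(c(20,  0,  0,
                          0, 12,  0,
                          0,  0, 18),
                       nrow = 3, byrow = TRUE,
                       dimnames = list(pred_labels = lbl, ref_labels = lbl)))

# Accuracy = correctly classified (diagonal) / all classified observations.
accuracy <- sum(diag(tab)) / sum(tab)
accuracy  # 50 / 50 = 1
```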