Reputation: 2067
I am trying to create a confusion matrix.
My data looks like:
class Growth Negative Neutral
1 Growth 0.3082588 0.2993632 0.3923780
2 Neutral 0.4696949 0.2918042 0.2385009
3 Negative 0.3608549 0.2679748 0.3711703
4 Neutral 0.3636836 0.2431433 0.3931730
5 Growth 0.4325862 0.2011520 0.3662619
6 Negative 0.2939859 0.2397171 0.4662970
where class is the "real" observed result and Growth, Negative and Neutral are the probabilities the model assigned to each of these classes. For example, in the first row the highest probability, 0.3923780, is for Neutral, so the model would incorrectly predict Neutral when the actual class was Growth.
I would usually use the confusionMatrix() function from caret, but my data is in a slightly different format. Should I create a new column called pred_class that holds the name of the column with the highest probability? Something like:
class Growth Negative Neutral pred_class
1 Growth 0.3082588 0.2993632 0.3923780 Neutral
2 Neutral 0.4696949 0.2918042 0.2385009 Growth
3 Negative 0.3608549 0.2679748 0.3711703 Neutral
4 Neutral 0.3636836 0.2431433 0.3931730 Neutral
5 Growth 0.4325862 0.2011520 0.3662619 Growth
6 Negative 0.2939859 0.2397171 0.4662970 Neutral
Then I could do something like confusionMatrix(df$pred_class, df$class). How can I write a function that puts the column name with the highest probability into a new column?
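Roughly what I have in mind is the following untested sketch (using apply() and which.max() over the three probability columns), but I'm not sure it's the idiomatic way to do it:
#Untested idea: for each row, take the name of the probability column
#with the largest value and store it in pred_class
df$pred_class <- apply(df[, c("Growth", "Negative", "Neutral")], 1,
                       function(p) names(p)[which.max(p)])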
Data:
df <- structure(list(class = c("Growth", "Neutral", "Negative", "Neutral",
"Growth", "Negative", "Neutral", "Neutral", "Neutral", "Neutral",
"Neutral", "Negative", "Neutral", "Growth", "Growth", "Growth",
"Negative", "Negative", "Growth", "Negative"), Growth = c(0.308258818045192,
0.469694864370061, 0.360854910973552, 0.363683641698332, 0.43258619401693,
0.2939858517149, 0.397951949316298, 0.235376278828237, 0.3685791718903,
0.330295647415191, 0.212072592205125, 0.220703558050626, 0.389445269278106,
0.286933037813081, 0.315659629884986, 0.30185119811882, 0.273429057319956,
0.277357131556229, 0.339004410008943, 0.407114176119814), Negative = c(0.299363167088292,
0.291804233603859, 0.267974798034839, 0.243143322044808, 0.201151951415105,
0.239717129555608, 0.351629585705591, 0.258325790152011, 0.281660024058527,
0.189920159505041, 0.265058882513953, 0.433664278547707, 0.114765460651494,
0.402354633060689, 0.370370354887748, 0.3239536031819, 0.3279406609037,
0.327198131828346, 0.298583999674218, 0.337902573718712), Neutral = c(0.392378014866516,
0.23850090202608, 0.371170290991609, 0.39317303625686, 0.366261854567965,
0.466297018729492, 0.250418464978111, 0.506297931019752, 0.349760804051173,
0.479784193079769, 0.522868525280922, 0.345632163401667, 0.4957892700704,
0.31071232912623, 0.313970015227266, 0.374195198699279, 0.398630281776344,
0.395444736615424, 0.362411590316838, 0.254983250161474)), row.names = c(NA,
20L), class = "data.frame")
Upvotes: 1
Views: 798
Reputation: 32558
#Vector of observed values
observed = df$class
#Remove first column from df so that we only have numeric values
temp = df[,-1]
#Obtain the predicted values based on column number
#of the maximum values in each row of temp
predicted = names(temp)[max.col(temp, ties.method = "first")]
#Create a union of the observed and predicted values
#so that all values are accounted for when we do 'table'
lvls = unique(c(observed, predicted))
#Convert observed and predicted values to factor
#with all levels that we created above
observed = factor(x = observed, levels = lvls)
predicted = factor(predicted, levels = lvls)
#Tabulate values
m = table(predicted, observed)
#Run confusionMatrix
library(caret)
confusionMatrix(m)
# Confusion Matrix and Statistics
# observed
# predicted Growth Neutral Negative
# Growth 1 3 1
# Neutral 3 5 4
# Negative 2 0 1
# Overall Statistics
# Accuracy : 0.35
# 95% CI : (0.1539, 0.5922)
# No Information Rate : 0.4
# P-Value [Acc > NIR] : 0.7500
# Kappa : -0.0156
# Mcnemar's Test P-Value : 0.2276
# Statistics by Class:
# Class: Growth Class: Neutral Class: Negative
# Sensitivity 0.1667 0.6250 0.1667
# Specificity 0.7143 0.4167 0.8571
# Pos Pred Value 0.2000 0.4167 0.3333
# Neg Pred Value 0.6667 0.6250 0.7059
# Prevalence 0.3000 0.4000 0.3000
# Detection Rate 0.0500 0.2500 0.0500
# Detection Prevalence 0.2500 0.6000 0.1500
# Balanced Accuracy 0.4405 0.5208 0.5119
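If you prefer to keep the prediction as a column in the data frame, as sketched in the question, a roughly equivalent alternative (a short sketch using the same max.col() idea, with caret already loaded) is:
#Alternative sketch: store the prediction as a column in df and
#pass the two factors to confusionMatrix directly
df$pred_class = names(df[,-1])[max.col(df[,-1], ties.method = "first")]
lvls = unique(c(df$class, df$pred_class))
confusionMatrix(factor(df$pred_class, levels = lvls),
                factor(df$class, levels = lvls))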
Upvotes: 1