ch.elahe

Reputation: 299

Misclassified samples in the final model of train function in Caret

The train function in the caret package returns a final model, and I would like to find the row indices of the misclassified samples in my main data frame. I do the cross-validation as follows:

library(caret)
train_control <- trainControl(method = "cv", number = 5, savePredictions = TRUE, classProbs = TRUE)
output <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

and then the final model would be:

> output$finalModel
Call:
randomForest(x = x, y = y, mtry = param$mtry) 
           Type of random forest: classification
                 Number of trees: 500
No. of variables tried at each split: 4

OOB estimate of  error rate: 4.67%
Confusion matrix:
             setosa versicolor virginica class.error
setosa         50          0         0        0.00
versicolor      0         47         3        0.06
virginica       0          4        46        0.08

Is there a way to find out which samples were misclassified (the 3 and 4 off-diagonal samples in the confusion matrix above)?

Upvotes: 2

Views: 1099

Answers (2)

raha.rah

Reputation: 428

Another easy way is to check the predicted samples:

output$finalModel$predicted

Then you can compare the predicted classes with your main iris data.
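
For example, a minimal sketch (assuming the output object from the question): the out-of-bag predictions stored in the final model are in the same order as the rows of iris, so comparing them with the true labels gives the misclassified row indices.

# OOB class predictions from the underlying randomForest model
pred <- output$finalModel$predicted

# Row indices where the OOB prediction disagrees with the true label
misclassified <- which(pred != iris$Species)
misclassified

# Inspect the misclassified rows in the original data frame
iris[misclassified, ]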

Upvotes: 1

Samuel

Reputation: 3053

Try this:

library(dplyr)
output$pred %>% filter(pred != obs)

Output:

         pred        obs setosa versicolor virginica rowIndex mtry Resample
1   virginica versicolor      0      0.084     0.916       71    2    Fold1
2  versicolor  virginica      0      0.976     0.024      107    2    Fold1
3   virginica versicolor      0      0.074     0.926       71    3    Fold1
4  versicolor  virginica      0      0.990     0.010      107    3    Fold1
5  versicolor  virginica      0      0.504     0.496      130    3    Fold1
6   virginica versicolor      0      0.070     0.930       71    4    Fold1
7  versicolor  virginica      0      0.992     0.008      107    4    Fold1
8  versicolor  virginica      0      0.550     0.450      130    4    Fold1
9   virginica versicolor      0      0.244     0.756       78    2    Fold2
10  virginica versicolor      0      0.172     0.828       78    3    Fold2
11  virginica versicolor      0      0.196     0.804       78    4    Fold2
12 versicolor  virginica      0      0.922     0.078      120    2    Fold3
13 versicolor  virginica      0      0.616     0.384      135    2    Fold3
14 versicolor  virginica      0      0.928     0.072      120    3    Fold3
15 versicolor  virginica      0      0.612     0.388      135    3    Fold3
16 versicolor  virginica      0      0.930     0.070      120    4    Fold3
17 versicolor  virginica      0      0.566     0.434      135    4    Fold3
18  virginica versicolor      0      0.352     0.648       84    2    Fold5
19  virginica versicolor      0      0.316     0.684       84    3    Fold5
20  virginica versicolor      0      0.256     0.744       84    4    Fold5

Note that mtry is the number of variables randomly sampled as candidates at each split, and that Resample lists the cross-validation fold.
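
If you only care about the tuning parameter that caret finally selected, one possible refinement (a sketch, assuming the same output object) is to restrict the predictions to the winning mtry before filtering:

library(dplyr)

# Keep only the predictions made with the selected mtry, then show the mismatches
output$pred %>%
  filter(mtry == output$bestTune$mtry, pred != obs)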

Let's plot the misclassified items:

d <- output$pred %>% 
  filter(pred != obs) %>% 
  distinct(rowIndex) %>% 
  unlist() %>% 
  sort()

print(unname(d))
# 71  78  84 107 120 130 134 135 139

library(ggplot2)

ggplot(iris, aes(Sepal.Length, Sepal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], aes(x = Sepal.Length, y = Sepal.Width), 
             color = "black")

ggplot(iris, aes(Petal.Length, Petal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], aes(x = Petal.Length, y = Petal.Width), 
             color = "black")

[Plot: Sepal.Length vs. Sepal.Width]

[Plot: Petal.Length vs. Petal.Width]

As can be seen, the plots give a visual explanation of the result: the misclassified samples are highlighted in black.

Upvotes: 0
