et_

sort `caret` models in `bwplot()`

I am plotting box plots of the accuracy scores from the resamples of yearly models trained with caret. The models are named after the years they refer to: 2000, 2001, 2002, ..., 2010. I want the models to appear in the box plots in ascending order of year, i.e. of model name.

The summary of the resamples, produced by the code below,

fit.year.res <- resamples(fit.year)
summary(fit.year.res)

looks like this:

[image: output of summary(fit.year.res)]

However, the yearly models in the box plot are not sorted:

scales <- list(x=list(relation="free"), y=list(relation="free"))
bwplot(fit.year.res, scales=scales)

[image: box plots of the yearly models, not in year order]

I have tried converting the models element of the resamples object (fit.year.res$models) from character to factor, but it didn't make any difference.


Answers (2)

missuse

I am not aware of an easy solution using the bwplot method from the caret package. Perhaps there is one, but my lattice skills are lacking. I recommend plotting the box plots manually using ggplot2; this way you have much better control over the final plot.

Since you did not post an example with data, I will use one of the examples from ?caret:::bwplot.resamples:

library(caret)
library(party)
library(RWeka)

load(url("http://topepo.github.io/caret/exampleModels.RData"))

resamps <- resamples(list(CART = rpartFit,
                          CondInfTree = ctreeFit,
                          MARS = earthFit))

bwplot(resamps,
       metric = "RMSE")

produces:

[image: bwplot of RMSE for CART, CondInfTree and MARS]

To make the plot manually using ggplot you will need some data manipulation:

library(tidyverse)
resamps$values %>% #extract the values
  select(1, ends_with("RMSE")) %>% #select the first column and all columns with a name ending with "RMSE"
  gather(model, RMSE, -1) %>% #convert to long table
  mutate(model = sub("~RMSE", "", model)) %>% #leave just the model names
  ggplot()+ #call ggplot
  geom_boxplot(aes(x = RMSE, y = model)) -> p1 #and plot the box plot

p1

[image: ggplot2 box plot of RMSE by model]

To set a specific order on the y axis:

p1 +
  scale_y_discrete(limits = c("MARS", "CART", "CondInfTree"))

[image: the same plot with the y axis in the order MARS, CART, CondInfTree]

If you prefer lattice:

library(lattice)

resamps$values %>%
  select(1, ends_with("RMSE")) %>%
  gather(model, RMSE, -1) %>%
  mutate(model = sub("~RMSE", "", model)) %>%
  {bwplot(model ~ RMSE, data = .)}

[image: lattice box plot of RMSE by model]

To change the order, change the levels of model (this approach also works with ggplot2):

resamps$values %>%
  select(1, ends_with("RMSE")) %>%
  gather(model, RMSE, -1) %>%
  mutate(model = sub("~RMSE", "", model),
         model = factor(model, levels = c("MARS", "CART", "CondInfTree"))) %>%
  {bwplot(model ~ RMSE, data = .)}
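Applied to the question, the same factor-levels approach only needs the year names as levels. Below is a minimal base-R sketch; `resample_long()` is a hypothetical helper (not part of caret), and it assumes, as the code above does, that the first column of `$values` holds the resample IDs and that the remaining columns are named `<model>~<metric>`.

```r
# Hypothetical helper (not part of caret): reshape the `values` data frame of a
# caret `resamples` object to long format for one metric, and order the model
# factor by the model names (numerically if every name is a number, e.g. years).
resample_long <- function(values, metric, model_order = NULL) {
  suffix <- paste0("~", metric)
  # columns for this metric look like "<model>~<metric>", e.g. "2000~Accuracy"
  cols <- names(values)[endsWith(names(values), suffix)]
  if (length(cols) == 0) stop("no columns found for metric ", metric)
  # strip the "~<metric>" suffix to get the model names
  models <- substr(cols, 1, nchar(cols) - nchar(suffix))
  long <- data.frame(
    Resample = rep(values[[1]], times = length(cols)),  # assumed resample IDs
    model = rep(models, each = nrow(values)),
    value = unlist(values[cols], use.names = FALSE),    # stacks column by column
    stringsAsFactors = FALSE
  )
  if (is.null(model_order)) {
    nums <- suppressWarnings(as.numeric(models))
    model_order <- if (all(!is.na(nums))) models[order(nums)] else sort(models)
  }
  long$model <- factor(long$model, levels = model_order)
  names(long)[names(long) == "value"] <- metric
  long
}
```

With lattice loaded, something like `bwplot(model ~ Accuracy, data = resample_long(fit.year.res$values, "Accuracy"))` should then show the years in order. Lattice draws the first factor level at the bottom of the y axis, so pass a reversed `model_order` if you want 2000 at the top.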

[image: lattice box plot with the models in the new order]


StupidWolf

The function bwplot.resamples is used to generate this plot, and if you look at its source code you can see that the models are ordered by their median performance on the (first) metric of interest.

Below is the relevant code that does the ordering:

bwplot.resamples <- function (x, data = NULL, models = x$models, metric = x$metric, ...)
{
....
  avPerf <- ddply(subset(plotData, Metric == metric[1]),
                  .(Model),
                  function(x) c(Median = median(x$value, na.rm = TRUE)))
  avPerf <- avPerf[order(avPerf$Median),]

    ......
}
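To see why converting `$models` to a factor has no effect, here is a base-R sketch of the ordering step in the excerpt above. It is not caret's code, and `median_order()` is a hypothetical name; the point is that the order is derived from the medians, not from any factor levels that already exist.

```r
# Sketch of the ordering step: given long-format resampling results
# (one model label and one metric value per row), return the model names
# sorted by increasing median value. Existing factor levels play no role.
median_order <- function(model, value) {
  med <- tapply(value, model, median, na.rm = TRUE)  # one median per model
  names(sort(med))                                   # ascending by median
}
```

For example, with medians of 0.85 for "2000", 0.55 for "2001" and 0.725 for "2002", the result is `c("2001", "2002", "2000")` whether `model` is a character vector or a factor with any level order.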

I guess what you need to do is make the plot manually. For example, with the BloodBrain data:

library(caret)  # train(), resamples() and the BloodBrain data; "gbm", "glmnet" and "rf" also need the gbm, glmnet and randomForest packages

data(BloodBrain)
gbmFit <- train(bbbDescr[, -3], logBBB, "gbm", tuneLength = 6,
                trControl = trainControl(method = "cv"), verbose = FALSE)

glmnetFit <- train(bbbDescr[, -3], logBBB, "glmnet", tuneLength = 6,
                   trControl = trainControl(method = "cv"))

rfFit <- train(bbbDescr[, -3], logBBB, "rf", tuneLength = 6,
               trControl = trainControl(method = "cv"))

knnFit <- train(bbbDescr[, -3], logBBB, "knn", tuneLength = 6,
                trControl = trainControl(method = "cv"))

resamps <- resamples(list(gbm = gbmFit, glmnet = glmnetFit, knn = knnFit, rf = rfFit))

If you plot it, you can see that the models are sorted by their medians (the solid dots):

bwplot(resamps,metric="MAE")

[image: bwplot of MAE, models sorted by median]

You can access the values under $values and write a function to plot them, something like this:

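plotMet <- function(obj, metric, var_order) {
  mat <- obj$values
  # keep only the columns for this metric; they are named "<model>~<metric>"
  mat <- mat[, grep(paste0("~", metric, "$"), colnames(mat)), drop = FALSE]
  # strip the "~<metric>" suffix so the columns are named by model
  colnames(mat) <- sub("~.*$", "", colnames(mat))
  # draw horizontal box plots in the order given by var_order
  boxplot(mat[, var_order], horizontal = TRUE, las = 2, xlab = metric)
}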

plotMet(resamps,"MAE",c("rf","knn","gbm","glmnet"))

[image: base R box plots of MAE in the order rf, knn, gbm, glmnet]

Also, it is not a very good idea to name your models with numbers; try something like model_2000, model_2001, etc.
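A minimal sketch of that renaming, assuming `fit.year` is a named list of `train` objects whose names are the years (the question does not show how it was built); `prefix_model_names()` is a hypothetical helper. Renaming alone does not change the order `bwplot()` uses, since that is still based on the medians.

```r
# Hypothetical helper: prefix the list names so the models are no longer
# named by bare numbers ("2000" becomes "model_2000"). The list elements
# themselves (e.g. caret train objects) are returned unchanged.
prefix_model_names <- function(fits, prefix = "model_") {
  names(fits) <- paste0(prefix, names(fits))
  fits
}
```

Usage would then be `fit.year.res <- resamples(prefix_model_names(fit.year))`.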
