user1885116

Reputation: 1797

Caret and rpart - defining the method

I am trying to familiarize myself with the caret package. Previously I used rpart directly, e.g. with the following syntax:

fit_rpart <- rpart(y ~ ., data = dt1, method = "anova")

I specified anova because I am aiming for regression (rather than classification).
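When `method` is omitted, rpart guesses it from the response: a factor outcome gives "class" and a plain numeric outcome gives "anova". The minimal check below uses the built-in `mtcars` and `iris` data as stand-ins for `dt1` (an assumption for illustration only) and compares an explicit `method = "anova"` fit with the default one.

```r
library(rpart)

# Same regression tree twice: once with method = "anova", once letting rpart
# infer the method from the numeric response mpg.
fit_explicit <- rpart(mpg ~ wt + hp + disp, data = mtcars, method = "anova")
fit_default  <- rpart(mpg ~ wt + hp + disp, data = mtcars)

fit_default$method                                # -> "anova"
all.equal(fit_explicit$frame, fit_default$frame)  # same splits and fitted values

# A factor response makes rpart switch to a classification tree instead.
fit_class <- rpart(Species ~ ., data = iris)
fit_class$method                                  # -> "class"
```

So for a numeric `y`, a regression tree is what rpart builds even without the explicit argument.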

With caret, I would use the following syntax:

rpart_fit <- train(y ~ ., data = dt1, method = "rpart", trControl = fitControl)

My question is: since the `method` argument is already taken, where/how can I still specify method = "anova"?

Many thanks in advance!

Upvotes: 0

Views: 2521

Answers (2)

Sander van den Oord

Reputation: 12838

In caret, `method` refers to the type of model you want to fit, for example rpart, lm (linear regression) or rf (random forest).

What you're referring to is called 'metric' in caret. If your y variable is continuous, the metric defaults to RMSE, which caret minimizes (lower is better). So you don't have to do anything.

You can also specify this explicitly:

rpart_fit <- train(y ~ ., data = dt1, method = "rpart", trControl = fitControl, metric = "RMSE")
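To make "minimizing RMSE" concrete, here is a rough sketch of what tuning `cp` on RMSE amounts to, using rpart alone. It assumes `mtcars` as example data, a single train/holdout split, and a hand-picked `cp_grid` (all hypothetical choices); caret itself resamples (by default with the bootstrap) and builds its own grid, so this illustrates the selection rule, not caret's exact procedure.

```r
library(rpart)

# Root-mean-squared error: the default tuning metric for a numeric outcome.
rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))

set.seed(42)
# One simple train/holdout split of mtcars (caret would resample instead).
idx      <- sample(nrow(mtcars), 22)
train_df <- mtcars[idx, ]
test_df  <- mtcars[-idx, ]

# Grow one deep regression tree, then prune it back for each candidate cp.
full_tree <- rpart(mpg ~ ., data = train_df, method = "anova",
                   control = rpart.control(cp = 0, minsplit = 5, xval = 0))

cp_grid <- c(0.001, 0.01, 0.05, 0.1, 0.3)  # hypothetical candidate values
holdout_rmse <- sapply(cp_grid, function(cp) {
  pruned <- prune(full_tree, cp = cp)
  rmse(predict(pruned, newdata = test_df), test_df$mpg)
})

# Lower RMSE is better, so the chosen cp is the one that minimizes it.
best_cp <- cp_grid[which.min(holdout_rmse)]
print(data.frame(cp = cp_grid, RMSE = holdout_rmse))
print(best_cp)
```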

Upvotes: 0

topepo

Reputation: 14331

You can make a custom method using the current rpart code. First, get the current code:

library(caret)
rpart_code <- getModelInfo("rpart", regex = FALSE)[[1]]

Then add the extra option to its fit function. The function is somewhat convoluted because it handles several different cases, but here is the edit (the only change is the added method = "anova" argument):

rpart_code$fit <- function(x, y, wts, param, lev, last, classProbs, ...) {
  cpValue <- if(!last) param$cp else 0
  theDots <- list(...)
  if(any(names(theDots) == "control")) {
    theDots$control$cp <- cpValue
    theDots$control$xval <- 0
    ctl <- theDots$control
    theDots$control <- NULL
  } else ctl <- rpart::rpart.control(cp = cpValue, xval = 0)

  ## check to see if weights were passed in (and available)
  if(!is.null(wts)) theDots$weights <- wts

  modelArgs <- c(list(formula = as.formula(".outcome ~ ."),
                      data = if(is.data.frame(x)) x else as.data.frame(x),
                      control = ctl,
                      method = "anova"),  ## the added option: always grow a regression tree
                 theDots)
  modelArgs$data$.outcome <- y

  out <- do.call(rpart::rpart, modelArgs)

  if(last) out <- rpart::prune.rpart(out, cp = param$cp)
  out
}
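The core of that edit is the argument list handed to `do.call`. The sketch below reproduces just that pattern outside caret, so it can be checked with rpart alone; the `x`/`y` stand-ins built from `mtcars` and the fixed `cp = 0.01` are assumptions for illustration, not what caret would pass.

```r
library(rpart)

# Stand-ins for what caret hands to a model's fit function:
# x = predictor data frame, y = outcome vector.
x <- mtcars[, c("wt", "hp", "disp")]
y <- mtcars$mpg

# Assemble the arguments the same way the custom fit function does,
# including the extra method = "anova" entry.
model_args <- list(formula = as.formula(".outcome ~ ."),
                   data    = x,
                   control = rpart.control(cp = 0.01, xval = 0),
                   method  = "anova")
model_args$data$.outcome <- y  # add the outcome column the formula refers to

fit <- do.call(rpart::rpart, model_args)
fit$method  # -> "anova"
```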

Then test it (`stagec` comes with rpart, and its outcome `pgstat` is coded numerically as 0/1):

library(rpart)
set.seed(445)
mod <- train(pgstat ~ age + eet + g2 + grade + gleason + ploidy, 
             data = stagec,
             method = rpart_code,
             tuneLength = 8)

Max

Upvotes: 1
