Alby

Reputation: 5752

Why is using "xgbTree" in caret so slow with trainControl?

I am trying to fit an xgboost model to a multiclass prediction problem and wanted to use caret for the hyperparameter search. To test the package, I used the following code. It takes 20 seconds when I do not give train a resampling trainControl object (method = "none"):

# just use one parameter combination
xgb_grid_1 <- expand.grid(
  nrounds = 1,
  eta = 0.3,
  max_depth = 5,
  gamma = 0,
  colsample_bytree = 1,
  min_child_weight = 1
)
# train
xgb_train_1 <- train(
  x = as.matrix(sparse_train),
  y = conversion_tbl$y_train_c,
  trControl = trainControl(method = "none",
                           classProbs = TRUE,
                           summaryFunction = multiClassSummary),
  metric = "logLoss",
  tuneGrid = xgb_grid_1,
  method = "xgbTree"
)

However, when I supply train with the following trainControl object (2-fold cross-validation), the code either never finishes or takes a very long time (it had not finished after 15 minutes):

xgb_trcontrol_1 <- trainControl(
  method = "cv",
  number = 2,
  verboseIter = TRUE, 
  returnData = FALSE,
  returnResamp = "none",                                         
  classProbs = TRUE,                                           
  summaryFunction = multiClassSummary
)
xgb_train_1 <- train(
  x = as.matrix(sparse_train),
  y = conversion_tbl$y_train_c,
  trControl = xgb_trcontrol_1,
  metric = "logLoss",
  tuneGrid = xgb_grid_1,
  method = "xgbTree"
)

Why is this?

FYI, my data size is

 dim(sparse_train)
[1] 702402     36

Upvotes: 2

Views: 8576

Answers (2)

Achekroud

Reputation: 261

Another thing you can try is adding nthread = 1 to the caret::train() call; caret passes it through to xgboost, so each xgboost fit uses a single thread.

Both xgboost and caret try to use parallel/multicore processing where possible, and in the past I have found that this (silently) spawns too many threads and throttles the machine.

Alternatively, telling caret to process the models in sequence (allowParallel = FALSE in trainControl()) minimizes the problem and should mean that only xgboost spawns threads.
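A minimal sketch of the two settings mentioned above, assuming caret and xgboost are installed and using the built-in iris data as a stand-in for the asker's sparse_train. Here nthread = 1 travels through train()'s `...` to xgboost, and allowParallel = FALSE stops caret from using a registered parallel backend (it has no effect if none is registered). The subsample column is included because recent caret versions require it in the "xgbTree" grid.

```r
# Sketch: keep xgboost single-threaded and caret sequential so the two
# layers of parallelism cannot multiply. iris is a stand-in dataset.
library(caret)    # assumes caret is installed
library(xgboost)  # assumes xgboost is installed (used by method = "xgbTree")

grid <- expand.grid(
  nrounds = 1, eta = 0.3, max_depth = 5, gamma = 0,
  colsample_bytree = 1, min_child_weight = 1,
  subsample = 1  # required by recent caret versions of "xgbTree"
)

ctrl <- trainControl(
  method = "cv",
  number = 2,
  allowParallel = FALSE  # caret fits the resamples one after another
)

set.seed(1)  # reproducible fold assignment
fit <- train(
  x = iris[, 1:4],
  y = iris$Species,
  method = "xgbTree",
  trControl = ctrl,
  tuneGrid = grid,
  nthread = 1  # forwarded through train()'s ... to xgboost
)
print(fit)
```

If you do want caret to parallelize across resamples (for example after registering a doParallel backend), keeping nthread = 1 is one way to avoid oversubscribing the cores.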

Upvotes: 4

Stenof

Reputation: 76

Your trainControl objects are different.

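In the first trainControl object, the method is method="none". In the second, the method is method="cv" with number=2, so you are running two-fold cross-validation: caret fits one model per fold and then refits the final model on the full training set. That is three fits instead of one, plus out-of-fold predictions and the multiClassSummary computation, so it takes longer than not running cross-validation.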
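To make the extra work concrete, the base-R sketch below mimics the shape of that loop: k-fold cross-validation followed by a final refit on all rows. It is not caret's internal code; lm() is a cheap stand-in for the xgboost fit, the data are synthetic, and cv_then_refit is a hypothetical helper written for illustration. It counts model fits and the total number of rows used for training.

```r
# Minimal sketch of k-fold CV followed by a final refit on all rows.
# lm() stands in for the xgboost fit; the data are synthetic.
set.seed(42)
n <- 1000
df <- data.frame(x = rnorm(n))
df$y <- 2 * df$x + rnorm(n)

cv_then_refit <- function(data, k = 2) {
  # Assign each row to one of k folds of (nearly) equal size
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))
  fits <- 0L
  rows_trained <- 0L
  holdout_rmse <- numeric(k)
  for (i in seq_len(k)) {
    train_part <- data[folds != i, , drop = FALSE]
    test_part <- data[folds == i, , drop = FALSE]
    fit <- lm(y ~ x, data = train_part)   # one fit per fold
    pred <- predict(fit, newdata = test_part)
    holdout_rmse[i] <- sqrt(mean((test_part$y - pred)^2))
    fits <- fits + 1L
    rows_trained <- rows_trained + nrow(train_part)
  }
  final <- lm(y ~ x, data = data)         # final refit on all rows
  fits <- fits + 1L
  rows_trained <- rows_trained + nrow(data)
  list(fits = fits, rows_trained = rows_trained,
       cv_rmse = mean(holdout_rmse), final = final)
}

res <- cv_then_refit(df, k = 2)
res$fits          # 3 fits, versus 1 without resampling
res$rows_trained  # 2000 rows trained on, twice a single 1000-row fit
```

By this rough count, two-fold CV alone should make a run take on the order of two to three times as long as method = "none", so a jump from 20 seconds to more than 15 minutes suggests that something else (such as the thread contention described in the other answer) may also be involved.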

Upvotes: 6
