S. Robinson

Reputation: 229

XGBoost Tree Plots in mlr3

I'm trying to plot a single tree from an XGBoost model fitted in mlr3, but I can't find any examples. I know this can be done with the underlying xgboost package (see here), which mlr3 calls, but I can't find a way to pull the tree structure out of an mlr3 model.

As an example:

#XGBoost tree plotting example

data(iris) #Iris dataset
sppNum <- as.numeric(iris$Species)-1 #Species as numeric (0-2), required for multiclass predictions

#Testing/training indices
set.seed(1)
testID <- as.vector(sapply(0:2,function(x) sort(sample(1:50,10,replace = FALSE)+(x*50)))) 
trainID <- which(!1:150 %in% testID)

# XGBoost example ---------------------------------------------------------

library(xgboost)
library(DiagrammeR)
#Make testing/training datasets
datTrain <- xgb.DMatrix(data=as.matrix(iris[trainID,1:4]),label=sppNum[trainID]) 
datTest <- xgb.DMatrix(data=as.matrix(iris[testID,1:4]),label=sppNum[testID])
watchlist <- list(train = datTrain, eval = datTest)

#Parameter list
parList <- list(max_depth = 3, eta = 1, verbose = 0, nthread = 1,
                objective = "multi:softmax", eval_metric = "auc", num_class=3)
xgbMod <- xgb.train(parList, datTrain, nrounds = 100, watchlist) #Fit model

xgb.plot.tree(model = xgbMod, trees = 1) #Plot a single tree (tree indices are 0-based, so trees = 1 is the second tree)

[Image: plot of a single tree]

# mlr3 example ------------------------------------------------------------

library(mlr3)
library(mlr3learners)
library(mlr3viz)

data(iris)

tsk_mlr3 <- as_task_classif(iris,target='Species') #Set up task
#Set up learner
lrn_mlr3 <- lrn('classif.xgboost',nrounds=100,max_depth = 3, eta = 1,
                eval_metric='auc') 
lrn_mlr3$train(tsk_mlr3,row_ids = trainID) #Train learner on subset
lrn_mlr3$predict(tsk_mlr3,row_ids = testID) #Predict

xgb.plot.tree(lrn_mlr3$model, trees=1) #Displays the following error:
#Error in xgb.plot.tree(lrn_mlr3$model, trees = 1) : 
#  model: Has to be an object of class xgb.Booster

Is there any way to plot at least one of the XGBoost trees from an mlr3 model?

Upvotes: 0

Views: 70

Answers (1)

alexrai93

Reputation: 643

The issue is just that you need to pass the model by name with model =. This code works for me:

xgb.plot.tree(model = lrn_mlr3$model, trees=1)

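I ran debugonce(xgb.plot.tree) and found that the model argument was NULL, so the function failed its check that model inherits from xgb.Booster. In the xgboost version used here, the first positional argument of xgb.plot.tree() is feature_names rather than model, so xgb.plot.tree(lrn_mlr3$model, trees = 1) passes the booster as feature_names and leaves model at its default of NULL. The plain xgboost example in the question works because it names the argument (model = xgbMod).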
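The same rule applies if you want the tree structure itself rather than a plot. The sketch below assumes mlr3, mlr3learners and xgboost are installed, and it uses the built-in tsk("iris") task with a small nrounds instead of the question's settings. It checks that the learner's $model is the raw xgb.Booster and then passes it by name to xgb.model.dt.tree(), which returns the trees as a table with one row per node. Argument names of these xgboost helpers have changed between package versions, so treat this as a sketch, not a version-independent recipe.

```r
library(mlr3)
library(mlr3learners)
library(xgboost)

# Train a small xgboost learner through mlr3
# (nrounds = 5 is an arbitrary choice to keep the example fast)
task <- tsk("iris")
learner <- lrn("classif.xgboost", nrounds = 5, max_depth = 3)
learner$train(task)

# mlr3learners stores the fitted xgboost object in $model
booster <- learner$model
print(class(booster))

# Pass the booster by name: positional matching can bind it to a
# different first argument and leave model = NULL
tree_dt <- xgb.model.dt.tree(model = booster)

# One row per node; the Tree column says which tree each node belongs to
print(head(tree_dt))

# Plotting uses the same named argument (requires DiagrammeR):
# xgb.plot.tree(model = booster, trees = 1)
```

Because a 3-class model grows one tree per class in every boosting round, this model contains several trees, and the Tree column can be used to pick out a single one.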

Upvotes: 1
