I'm trying to plot a single tree from an XGBoost model fit in mlr3, and I can't find any examples. I know there's a way to do it with the underlying xgboost package (see here), which mlr3 calls, but I can't find any way to pull out the structure of the trees from an mlr3 model.
As an example:
#XGBoost tree plotting example
data(iris) #Iris dataset
sppNum <- as.numeric(iris$Species)-1 #Species as numeric (0-2), required for multiclass predictions
#Testing/training indices
set.seed(1)
testID <- as.vector(sapply(0:2,function(x) sort(sample(1:50,10,replace = FALSE)+(x*50))))
trainID <- which(!1:150 %in% testID)
# XGboost example ---------------------------------------------------------
library(xgboost)
library(DiagrammeR)
#Make testing/training datasets
datTrain <- xgb.DMatrix(data=as.matrix(iris[trainID,1:4]),label=sppNum[trainID])
datTest <- xgb.DMatrix(data=as.matrix(iris[testID,1:4]),label=sppNum[testID])
watchlist <- list(train = datTrain, eval = datTest)
#Parameter list
parList <- list(max_depth = 3, eta = 1, verbose = 0, nthread = 1,
objective = "multi:softmax", eval_metric = "auc", num_class=3)
xgbMod <- xgb.train(parList, datTrain, nrounds = 100, watchlist) #Fit model
xgb.plot.tree(model = xgbMod, trees = 1) #Plot a single tree (tree indices are 0-based in xgboost 1.x)
# mlr3 example ------------------------------------------------------------
library(mlr3)
library(mlr3learners)
library(mlr3viz)
data(iris)
tsk_mlr3 <- as_task_classif(iris,target='Species') #Set up task
#Set up learner
lrn_mlr3 <- lrn('classif.xgboost',nrounds=100,max_depth = 3, eta = 1,
eval_metric='auc')
lrn_mlr3$train(tsk_mlr3,row_ids = trainID) #Train learner on subset
lrn_mlr3$predict(tsk_mlr3,row_ids = testID) #Predict
xgb.plot.tree(lrn_mlr3$model, trees=1) #Displays the following error:
#Error in xgb.plot.tree(lrn_mlr3$model, trees = 1) :
# model: Has to be an object of class xgb.Booster
Is there any way to plot at least one of the trees from XGBoost in mlr3?
The issue is just that you need to name the argument explicitly (model = ...). This code works for me:
xgb.plot.tree(model = lrn_mlr3$model, trees=1)
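The same fix should carry over if you want the tree structure as a table instead of a plot. The sketch below assumes, as the working call above implies, that the `classif.xgboost` learner from mlr3learners keeps the fitted `xgb.Booster` in `$model`, and it passes that booster by name to `xgb.model.dt.tree()`. Argument names other than `model` have changed across xgboost releases, so check `?xgb.model.dt.tree` for your installed version. The `requireNamespace()` check only skips the example when the packages are not installed.

```r
# Extract the tree structure from an mlr3 xgboost learner as a table.
# Skipped quietly if mlr3, mlr3learners or xgboost is not installed.
pkgs <- c("mlr3", "mlr3learners", "xgboost")
tree_dt <- NULL
if (all(vapply(pkgs, requireNamespace, logical(1), quietly = TRUE))) {
  library(mlr3)
  library(mlr3learners)  # registers the "classif.xgboost" learner
  library(xgboost)

  task <- as_task_classif(iris, target = "Species")
  learner <- lrn("classif.xgboost", nrounds = 10, max_depth = 3, eta = 1)
  learner$train(task)

  booster <- learner$model  # assumed to be the underlying xgb.Booster
  stopifnot(inherits(booster, "xgb.Booster"))

  # Name the argument so the booster is matched to `model`,
  # not to whatever parameter comes first in this xgboost version.
  tree_dt <- xgb.model.dt.tree(model = booster)
  print(head(tree_dt))  # one row per node, with a Tree column identifying the tree
}
```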
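I ran debugonce(xgb.plot.tree) and found that the model argument was NULL, so it failed the check that model inherits from xgb.Booster. Without the name, lrn_mlr3$model is matched by position to the first parameter, which in xgboost 1.x is feature_names, not model.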
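To see why naming the argument matters, here is a small base-R sketch. `plot_tree_stub()` and `fake_booster` are hypothetical stand-ins, not xgboost functions: the stub only copies the leading parameter order of `xgb.plot.tree()` in xgboost 1.x (feature_names first, then model) and its class check, so no packages are needed.

```r
# Hypothetical stand-in that mimics only the argument order and the class
# check of xgb.plot.tree() in xgboost 1.x: feature_names first, then model.
plot_tree_stub <- function(feature_names = NULL, model = NULL, trees = NULL) {
  if (!inherits(model, "xgb.Booster")) {
    stop("model: Has to be an object of class xgb.Booster")
  }
  "ok"
}

# Any object carrying the xgb.Booster class is enough for this check.
fake_booster <- structure(list(), class = "xgb.Booster")

# Positional call: fake_booster is matched to feature_names, so model stays NULL.
positional <- tryCatch(plot_tree_stub(fake_booster, trees = 1),
                       error = function(e) conditionMessage(e))

# Named call: the booster reaches `model` and the check passes.
named <- plot_tree_stub(model = fake_booster, trees = 1)

print(positional)  # the same "Has to be an object of class xgb.Booster" message
print(named)       # "ok"
```

This reproduces the question's error message without mlr3: the object is fine, it just lands in the wrong parameter.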