Reputation: 191
I have a question about XGBoost.
Is there a way to find out how many trees XGBoost has created? Unlike RandomForest, where the person building the model decides how many trees are made, XGBoost basically keeps creating trees until the loss function reaches a certain value, so I would like to know the resulting count.
Thank you.
Upvotes: 17
Views: 13255
Reputation: 5387
It may be too late to submit an answer, but I recently did the following:
Load the model and save it as a JSON file. Then load the JSON file and print the num_trees value from it.
See the code snippet below. After saving it to a file named xgb_tree_count.py, run it as: python xgb_tree_count.py model-file-path
import sys
import json
import xgboost as xgb

if len(sys.argv) < 2:
    print(f'Usage: {sys.argv[0]} <model-file>')
    sys.exit(1)

# Load the model, then re-save it as JSON so it can be parsed
loaded_model = xgb.Booster()
loaded_model.load_model(sys.argv[1])
loaded_model.save_model('/tmp/a_model.json')

with open('/tmp/a_model.json', 'r') as fp:
    jsonrepr = json.load(fp)

# num_trees is stored under the gradient booster's model parameters
print(jsonrepr['learner']['gradient_booster']['model']['gbtree_model_param']['num_trees'])
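As a possible variation on the script above, the lookup can be wrapped in a small function that works on the already-parsed JSON and cross-checks the stored count against the list of trees. This is only a sketch: the function name count_trees_in_model is made up for illustration, and it assumes the model uses the gbtree booster and that the JSON has a "trees" list next to "gbtree_model_param" (a dart model nests these differently).

```python
def count_trees_in_model(model):
    """Return the tree count from a parsed XGBoost JSON model (a dict).

    Assumes the gbtree layout used in the script above:
    learner -> gradient_booster -> model -> gbtree_model_param / trees.
    """
    gbtree = model["learner"]["gradient_booster"]["model"]
    # num_trees is stored as a string in the JSON file, so convert it
    declared = int(gbtree["gbtree_model_param"]["num_trees"])
    # Each entry of "trees" describes one tree (assumed layout)
    actual = len(gbtree["trees"])
    if declared != actual:
        raise ValueError(
            f"num_trees says {declared} but the file contains {actual} trees"
        )
    return actual
```

With the script above, this would be called as count_trees_in_model(jsonrepr) right after json.load(fp).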
Upvotes: 0
Reputation: 68
In Java, there appears to be no direct way to do this. You can, however, use the result of a model dump to get the actual number of trees. Using a trained Booster:
int numberOfTrees = booster.getModelDump("", false, "text").length;
Upvotes: -1
Reputation: 4594
It's a bit crooked, but what I'm currently doing is dumping the model (XGBoost produces a list where each element is a string representation of a single tree) and then counting how many elements are in the list:
# clf is an XGBoost model fitted using the sklearn API
dump_list = clf.get_booster().get_dump()
num_trees = len(dump_list)
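One thing to keep in mind when reading that number: the dump lists every tree, and XGBoost builds more than one tree per boosting round for multiclass objectives (one per class) or when num_parallel_tree is greater than 1. A rough sketch of converting the tree count back into boosting rounds follows; the helper name boosting_rounds is hypothetical, and you have to supply num_class and num_parallel_tree yourself from your training parameters.

```python
def boosting_rounds(total_trees, num_class=1, num_parallel_tree=1):
    """Convert a total tree count into the number of boosting rounds.

    For regression and binary classification num_class stays 1;
    for multiclass objectives pass the number of classes.
    """
    trees_per_round = num_class * num_parallel_tree
    if total_trees % trees_per_round != 0:
        raise ValueError("tree count is not a multiple of trees per round")
    return total_trees // trees_per_round


# Example: a 3-class model whose dump has 300 entries ran 100 rounds
rounds = boosting_rounds(300, num_class=3)
```

For a binary classifier or a regressor with default settings, the length of the dump list and the number of rounds are the same.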
Upvotes: 21
Reputation: 3223
This is controlled by you as the user. If you use the native training API, it is controlled by num_boost_round (default is 10); see the docs here:
num_boost_round (int) – Number of boosting iterations.
If you use the sklearn API, it is controlled by n_estimators (default is 100); see the docs here:
n_estimators : int Number of boosted trees to fit.
The only caveat is that this is the maximum number of trees to fit; fitting can stop earlier if you set up an early stopping criterion. I'm not sure whether you use that.
Upvotes: -2