I have an xgboost dump text file containing many trees. I want to find every root-to-leaf path and the leaf value at the end of each one. Here is one of the trees:
tree[0]:
0:[a<0.966398] yes=1,no=2,missing=1
    1:[b<0.323071] yes=3,no=4,missing=3
        3:[c<0.461248] yes=7,no=8,missing=7
            7:leaf=0.00972768
            8:leaf=-0.0179376
        4:[a<0.379082] yes=9,no=10,missing=9
            9:leaf=0.0146003
            10:leaf=0.0454369
    2:[b<0.322352] yes=5,no=6,missing=5
        5:[c<0.674868] yes=11,no=12,missing=11
            11:leaf=0.0497964
            12:leaf=0.00953781
        6:[f<0.598267] yes=13,no=14,missing=13
            13:leaf=0.0504545
            14:leaf=0.0867654
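Each line of the dump is either a split (`id:[feature<threshold] yes=…,no=…,missing=…`) or a leaf (`id:leaf=value`), so the file can be parsed mechanically before any path-finding. Below is a minimal Python sketch under that assumption; `parse_dump` is a hypothetical helper, and it only recognizes the `<` split syntax shown in this tree.

```python
import re

# Tree header as in the question ("tree[0]:"); also accepting
# "booster[0]:" is an assumption about other dump variants.
HEADER_RE = re.compile(r"^(?:tree|booster)\[(\d+)\]:")
# Split node, e.g. "3:[c<0.461248] yes=7,no=8,missing=7"
SPLIT_RE = re.compile(r"^(\d+):\[(.+?)<([^\]]+)\]\s+yes=(\d+),no=(\d+),missing=(\d+)")
# Leaf node, e.g. "7:leaf=0.00972768" (stops at a comma in case stats follow)
LEAF_RE = re.compile(r"^(\d+):leaf=([^,\s]+)")


def parse_dump(text):
    """Return {tree_index: {node_id: node}} for a text dump.

    Split nodes become {"feature", "threshold", "yes", "no", "missing"};
    leaves become {"leaf": value}. Indentation is ignored.
    """
    trees = {}
    current = None
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        header = HEADER_RE.match(line)
        if header:
            current = trees.setdefault(int(header.group(1)), {})
            continue
        if current is None:  # no header seen yet: treat lines as tree 0
            current = trees.setdefault(0, {})
        split = SPLIT_RE.match(line)
        if split:
            nid, feat, thr, yes, no, missing = split.groups()
            current[int(nid)] = {
                "feature": feat,
                "threshold": float(thr),
                "yes": int(yes),
                "no": int(no),
                "missing": int(missing),
            }
            continue
        leaf = LEAF_RE.match(line)
        if leaf:
            current[int(leaf.group(1))] = {"leaf": float(leaf.group(2))}
    return trees
```

Node IDs are kept as dictionary keys rather than array positions, so nothing depends on the tree being complete.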
I want to transform every path into:
path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00972768
path2, a<0.966398, b<0.323071, c>=0.461248, leaf = -0.0179376
path3, a<0.966398, b>=0.323071, a<0.379082, leaf = 0.0146003
path4, a<0.966398, b>=0.323071, a>=0.379082, leaf = 0.0454369
path5, a>=0.966398, b<0.322352, c<0.674868, leaf = 0.0497964
path6, a>=0.966398, b<0.322352, c>=0.674868, leaf = 0.00953781
path7, a>=0.966398, b>=0.322352, f<0.598267, leaf = 0.0504545
path8, a>=0.966398, b>=0.322352, f>=0.598267, leaf = 0.0867654
I have already tried listing all possible paths as node IDs, assuming a complete binary tree:
array([[ 0, 1, 3, 7],
[ 0, 1, 3, 8],
[ 0, 1, 4, 9],
[ 0, 1, 4, 10],
[ 0, 2, 5, 11],
[ 0, 2, 5, 12],
[ 0, 2, 6, 13],
[ 0, 2, 6, 14]])
But this breaks once max_depth is higher: some branches stop growing early, so the node IDs no longer follow this pattern and the paths come out wrong. I therefore need to parse the yes/no fields in the text file to generate the real, correct paths. Any suggestions? Thank you!
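Once the dump is parsed into a dictionary of nodes, the real paths come from following each split's `yes`/`no` IDs from the root, instead of assuming the complete-tree numbering in the array above. A minimal Python sketch, assuming a hypothetical node shape of `{"feature", "threshold", "yes", "no", "missing"}` for splits and `{"leaf": value}` for leaves; `enumerate_paths` is a hypothetical helper. It writes `>=` for the `no` branch because xgboost sends a value to `yes` only when it is below the threshold.

```python
def enumerate_paths(nodes, root=0):
    """Depth-first walk over yes/no links; returns [(conditions, leaf_value)]."""
    paths = []
    stack = [(root, [])]
    while stack:
        nid, conds = stack.pop()
        node = nodes[nid]
        if "leaf" in node:
            paths.append((conds, node["leaf"]))
            continue
        feat, thr = node["feature"], node["threshold"]
        # Push "no" first so the "yes" subtree is popped (and listed) first.
        stack.append((node["no"], conds + [f"{feat}>={thr}"]))
        stack.append((node["yes"], conds + [f"{feat}<{thr}"]))
    return paths


# Hypothetical unbalanced tree: node 1 is already a leaf and node 2's children
# are numbered 3 and 4, so complete-tree numbering (5 and 6) would not apply.
nodes = {
    0: {"feature": "a", "threshold": 0.5, "yes": 1, "no": 2, "missing": 1},
    1: {"leaf": 0.1},
    2: {"feature": "b", "threshold": 0.3, "yes": 3, "no": 4, "missing": 3},
    3: {"leaf": -0.2},
    4: {"leaf": 0.4},
}

for i, (conds, leaf) in enumerate(enumerate_paths(nodes), start=1):
    print(f"path{i}, " + ", ".join(conds) + f", leaf = {leaf}")
# path1, a<0.5, leaf = 0.1
# path2, a>=0.5, b<0.3, leaf = -0.2
# path3, a>=0.5, b>=0.3, leaf = 0.4
```

Missing values follow the `missing=` ID; in the tree above it always equals `yes`, so each `<` condition there also covers a missing feature value.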
Here's how I approached this problem using the R implementation of xgboost. Users of other languages can follow the same logic and replicate it.
First, I started with the table of tree nodes that xgb.model.dt.tree() generates from the model dump.
Then, I wrote a function that traces the valid pathway from an arbitrary node up to the root of its individual tree in the dumped model.
Finally, I apply this function to every terminal-node ("Leaf") record from the model dump using purrr::by_row(), and transform the results for my own purposes.
This function takes two arguments: the tree being tested and the ID of the terminal node. In general terms, it starts at that terminal node and works upward, one parent at a time, until it reaches the root of the tree.
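As an illustration (not the author's R code), the same leaf-to-root idea can be sketched in Python: build a child-to-parent map from each split's `yes`/`no` IDs, then climb from a terminal node until no parent remains. The node-dictionary shape and the names `build_parent_map` and `path_to_root` are hypothetical.

```python
def build_parent_map(nodes):
    """Map child id -> (parent id, condition on the edge from parent to child)."""
    parents = {}
    for nid, node in nodes.items():
        if "leaf" in node:
            continue
        feat, thr = node["feature"], node["threshold"]
        parents[node["yes"]] = (nid, f"{feat}<{thr}")   # yes: value < threshold
        parents[node["no"]] = (nid, f"{feat}>={thr}")   # no: value >= threshold
    return parents


def path_to_root(parents, leaf_id):
    """Climb from leaf_id to the root; return conditions ordered root -> leaf."""
    conds = []
    nid = leaf_id
    while nid in parents:  # the root is nobody's child, so the loop stops there
        nid, cond = parents[nid]
        conds.append(cond)
    conds.reverse()
    return conds


# Hypothetical small tree: splits carry feature/threshold/yes/no, leaves carry "leaf".
nodes = {
    0: {"feature": "a", "threshold": 0.5, "yes": 1, "no": 2},
    1: {"leaf": 0.1},
    2: {"feature": "b", "threshold": 0.3, "yes": 3, "no": 4},
    3: {"leaf": -0.2},
    4: {"leaf": 0.4},
}
parents = build_parent_map(nodes)
print(path_to_root(parents, 4))  # ['a>=0.5', 'b>=0.3']
```

Like the function described above, `path_to_root` takes the tree (here, its parent map) and a terminal-node ID; building the parent map once per tree avoids searching every row for each step.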
Specifically, I call purrr::by_row() on the "Leaf" rows with .collate = "rows", so that each pathway is represented as additional rows in the output.
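Collating as rows turns each pathway into several output rows, one per step. A comparable Python sketch (hypothetical `all_leaf_rows`, assuming the same kind of node dictionary per tree) applies the leaf-to-root walk to every leaf of every tree and emits `(tree, leaf, step, condition, leaf_value)` tuples.

```python
def all_leaf_rows(trees):
    """Return (tree, leaf, step, condition, leaf_value) rows for every leaf.

    `trees` maps tree index -> {node_id: node}; split nodes hold
    "feature", "threshold", "yes", "no" and leaves hold "leaf".
    A tree that is a single leaf has no conditions and yields no rows.
    """
    rows = []
    for tree_id, nodes in sorted(trees.items()):
        # child id -> (parent id, condition on the edge parent -> child)
        parents = {}
        for nid, node in nodes.items():
            if "leaf" not in node:
                feat, thr = node["feature"], node["threshold"]
                parents[node["yes"]] = (nid, f"{feat}<{thr}")
                parents[node["no"]] = (nid, f"{feat}>={thr}")
        leaf_ids = sorted(nid for nid, node in nodes.items() if "leaf" in node)
        for leaf_id in leaf_ids:
            conds, nid = [], leaf_id
            while nid in parents:  # climb until the root
                nid, cond = parents[nid]
                conds.append(cond)
            # Reverse so step 1 is the root's split.
            for step, cond in enumerate(reversed(conds), start=1):
                rows.append((tree_id, leaf_id, step, cond, nodes[leaf_id]["leaf"]))
    return rows


trees = {
    0: {
        0: {"feature": "a", "threshold": 0.5, "yes": 1, "no": 2},
        1: {"leaf": 0.1},
        2: {"feature": "b", "threshold": 0.3, "yes": 3, "no": 4},
        3: {"leaf": -0.2},
        4: {"leaf": 0.4},
    }
}
for row in all_leaf_rows(trees):
    print(row)
```

The long format keeps every pathway in one flat table, which is easy to filter or group by tree and leaf afterwards.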
This is very likely not the fastest approach.
Increasing nrounds or max_depth in the xgb.Booster model increases the runtime of this process. You can develop your method on a subset of trees (the n_first_tree = N argument of xgb.model.dt.tree()) and use it to estimate the time required to parse out all terminal-node pathways across the final model. In my case, models with ~500 trees at max_depth = 5 can take upwards of 30 minutes.
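Timing a subset amounts to linear extrapolation: if each tree costs roughly the same at a fixed max_depth (an assumption), total time ≈ time for the first N trees × (total trees / N). For scale, 30 minutes over ~500 trees works out to about 1800 / 500 = 3.6 s per tree. A minimal Python sketch with a hypothetical `process_tree` callable standing in for the per-tree path parsing:

```python
import time


def estimate_total_seconds(process_tree, trees, n_sample):
    """Time process_tree on the first n_sample trees and scale to all trees.

    Assumes every tree costs about the same (e.g. same max_depth), so the
    per-tree average from the sample extrapolates linearly.
    """
    sample = trees[:n_sample]
    if not sample:
        raise ValueError("n_sample must be at least 1 and trees non-empty")
    start = time.perf_counter()
    for tree in sample:
        process_tree(tree)
    elapsed = time.perf_counter() - start
    return elapsed / len(sample) * len(trees)


# Stand-in workload: sleep 10 ms per "tree" (hypothetical, for illustration).
estimate = estimate_total_seconds(lambda tree: time.sleep(0.01), list(range(500)), 20)
print(f"estimated total: {estimate:.1f} s")
```

If deeper trees dominate the cost, sample trees of the same max_depth as the final model, since the per-tree time grows with the number of leaves and their depth.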