jimmy15923

Reputation: 291

Find all paths in a binary tree from xgboost.dump

I have an xgboost.dump text file containing many trees. I want to find every root-to-leaf path and the leaf value at the end of each path. Here is one of the trees.

tree[0]:
0:[a<0.966398] yes=1,no=2,missing=1
    1:[b<0.323071] yes=3,no=4,missing=3
        3:[c<0.461248] yes=7,no=8,missing=7
            7:leaf=0.00972768
            8:leaf=-0.0179376
        4:[a<0.379082] yes=9,no=10,missing=9
            9:leaf=0.0146003
            10:leaf=0.0454369
    2:[b<0.322352] yes=5,no=6,missing=5
        5:[c<0.674868] yes=11,no=12,missing=11
            11:leaf=0.0497964
            12:leaf=0.00953781
        6:[f<0.598267] yes=13,no=14,missing=13
            13:leaf=0.0504545
            14:leaf=0.0867654

I want to transform every path into the following form:

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00972768
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0867654

I have already tried listing all possible paths, like this:

array([[ 0,  1,  3,  7],
       [ 0,  1,  3,  8],
       [ 0,  1,  4,  9],
       [ 0,  1,  4, 10],
       [ 0,  2,  5, 11],
       [ 0,  2,  5, 12],
       [ 0,  2,  6, 13],
       [ 0,  2,  6, 14]])

But this approach breaks once max_depth is higher: some branches stop growing early, so the enumerated paths are wrong. I therefore need to parse the yes and no fields in the text file to generate the correct paths. Any suggestions? Thank you!

Upvotes: 2

Views: 1224

Answers (1)

Harney

Reputation: 346

Here's how I approached this problem using the R implementation. Users of other languages can follow the logic and replicate it.

First, I started with the model dump file generated by xgb.model.dt.tree().
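For Python users working from the text dump in the question, a rough equivalent of that first step is to parse the file into one node table per tree. The sketch below is only an illustration: `parse_dump` and the record field names are hypothetical, and the regular expressions assume lines shaped like the sample in the question (a `tree[N]:` or `booster[N]:` header, split lines with `yes=`, `no=`, `missing=`, and `leaf=` lines).

```python
import re

# Line shapes assumed from the sample dump in the question.
TREE_RE = re.compile(r"^\s*(?:tree|booster)\[(\d+)\]")
SPLIT_RE = re.compile(
    r"^\s*(\d+):\[([^<\]]+)<([^\]]+)\]\s*yes=(\d+),no=(\d+),missing=(\d+)"
)
LEAF_RE = re.compile(r"^\s*(\d+):leaf=([^,\s]+)")


def parse_dump(text):
    """Return {tree_id: {node_id: record}} parsed from a text dump."""
    trees = {}
    current = None
    for line in text.splitlines():
        m = TREE_RE.match(line)
        if m:
            # Start a new node table for this tree.
            current = trees.setdefault(int(m.group(1)), {})
            continue
        if current is None:
            continue  # ignore anything before the first tree header
        m = SPLIT_RE.match(line)
        if m:
            node, feature, threshold, yes, no, missing = m.groups()
            current[int(node)] = {
                "feature": feature,
                "threshold": float(threshold),
                "yes": int(yes),
                "no": int(no),
                "missing": int(missing),
            }
            continue
        m = LEAF_RE.match(line)
        if m:
            current[int(m.group(1))] = {"leaf": float(m.group(2))}
    return trees


# A shortened copy of tree[0] from the question, used as example input.
DUMP = """tree[0]:
0:[a<0.966398] yes=1,no=2,missing=1
    1:[b<0.323071] yes=3,no=4,missing=3
        3:leaf=0.00972768
        4:leaf=-0.0179376
    2:leaf=0.0867654
"""

trees = parse_dump(DUMP)
```

Because each split record keeps its own `yes` and `no` child IDs, the table stays correct even when some branches stop growing before max_depth.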

Then, I wrote a function to trace the valid pathway from an arbitrary node up to the ultimate parent (the root) within an individual tree of the dumped model.

Finally, I applied this function to all terminal-node ("Leaf") records from the model dump using purrr::by_row(), and transformed the results for my purposes.

The function takes two arguments: the tree being searched and the ID of the terminal node. It follows these general steps:

  1. Beginning with a target (terminal) node on a per-tree basis, find the row that has the target node as a valid child among the c("Yes", "No", "Missing") decision splits.
  2. Concatenate this valid parent node ID into a vector that will be used to trace each step of the pathway from the target node up to the ultimate parent. This vector is returned at the completion of the function.
  3. Next, repeat the "who is my parent" step for each node up the chain until the pathway hits the ultimate parent (this node ID always ends in "-0"), while updating the pathway vector for each new step in the chain.
  4. Once the function hits the ultimate parent, return() the pathway.
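The steps above can be sketched in Python as follows. This is a minimal illustration, not the original R function: `find_parent` and `path_to_root` are hypothetical names, the `SPLITS` table is typed in by hand from tree[0] in the question, and the root is assumed to have ID 0 (the text dump uses plain integers rather than the "-0" suffix of the R table).

```python
# Child links copied from tree[0] in the question: node -> (yes, no, missing).
SPLITS = {
    0: (1, 2, 1),
    1: (3, 4, 3),
    2: (5, 6, 5),
    3: (7, 8, 7),
    4: (9, 10, 9),
    5: (11, 12, 11),
    6: (13, 14, 13),
}


def find_parent(splits, node):
    """Step 1: find the split that lists `node` among its children."""
    for parent, children in splits.items():
        if node in children:
            return parent
    return None


def path_to_root(splits, leaf, root=0):
    """Steps 2-4: climb from `leaf` to `root`, recording each node visited."""
    path = [leaf]
    node = leaf
    while node != root:
        node = find_parent(splits, node)
        if node is None:
            raise ValueError(f"node {path[-1]} has no parent in this tree")
        path.append(node)
    # Reverse so the path reads root -> leaf.
    return path[::-1]


# Leaves are the child IDs that never appear as a split node themselves.
leaves = sorted({c for children in SPLITS.values() for c in children} - set(SPLITS))
paths = [path_to_root(SPLITS, leaf) for leaf in leaves]
```

Since every path is built from the actual parent links, leaves that sit at different depths simply yield paths of different lengths.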

In my case, I applied this function to all "Leaf" nodes in the model dump using purrr::by_row() with .collate = "rows", so that each step of the pathway appears as an additional row in the output.

This is very likely not the fastest way possible.

An increase in nrounds or max_depth in the xgb.Booster model will increase the runtime of this process. You can develop your method using a subset of trees (xgb.model.dt.tree()'s argument n_first_tree = N) to estimate the time required to parse out the entirety of the terminal-node pathways across the final model. In my case, models with ~500 trees at max_depth = 5 can take upwards of 30 minutes.
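If runtime matters, one alternative worth trying is to walk each tree from the root downwards, so that every node is visited once instead of searching for a parent at each step. This is a different technique from the parent-tracing function described above, shown only as a sketch: `all_paths` is a hypothetical helper, `TREE` is typed in by hand from tree[0] in the question, the "no" branch is written as `>=` because the "yes" branch is `<`, and the `missing` direction is ignored.

```python
# node -> (feature, threshold, yes_child, no_child) for splits,
# or a plain float for leaves. Values copied from tree[0] in the question.
TREE = {
    0: ("a", 0.966398, 1, 2),
    1: ("b", 0.323071, 3, 4),
    2: ("b", 0.322352, 5, 6),
    3: ("c", 0.461248, 7, 8),
    4: ("a", 0.379082, 9, 10),
    5: ("c", 0.674868, 11, 12),
    6: ("f", 0.598267, 13, 14),
    7: 0.00972768, 8: -0.0179376, 9: 0.0146003, 10: 0.0454369,
    11: 0.0497964, 12: 0.00953781, 13: 0.0504545, 14: 0.0867654,
}


def all_paths(tree, root=0):
    """Depth-first walk from the root; returns [(conditions, leaf_value), ...]."""
    results = []
    stack = [(root, [])]
    while stack:
        node, conds = stack.pop()
        entry = tree[node]
        if isinstance(entry, tuple):
            feature, threshold, yes, no = entry
            # Push "no" first so the "yes" branch is explored first.
            stack.append((no, conds + [f"{feature}>={threshold}"]))
            stack.append((yes, conds + [f"{feature}<{threshold}"]))
        else:
            # Reached a leaf: record the accumulated conditions and its value.
            results.append((conds, entry))
    return results


rows = all_paths(TREE)
for i, (conds, leaf) in enumerate(rows, start=1):
    print(f"path{i}, {', '.join(conds)}, leaf = {leaf}")
```

The walk follows only the child IDs stored on each split, so branches that end early produce shorter condition lists rather than wrong paths.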

Upvotes: 2
