Reputation: 796
I need to extract the decision rules from my fitted xgboost model in Python. I am using version 0.6a2 of the xgboost library with Python 3.5.2.
My ultimate goal is to use those splits to bin variables.
I have not come across any property of the model in this version that gives me the splits.
plot_tree gives me something similar, but it is a visualization of the tree.
I need something like https://stackoverflow.com/a/39772170/4559070 for an xgboost model.
Upvotes: 21
Views: 19699
Reputation: 705
Below is a code fragment that prints all the rules extracted from the booster trees of an xgboost model.
The code assumes that missing values were substituted with 999999 before training.
import networkx as nx

df = model._Booster.trees_to_dataframe()

G = nx.DiGraph()
G.add_nodes_from(df.ID.tolist())

# In xgboost the split condition is "feature < Split": the Yes branch is
# taken when the condition holds, the No branch otherwise. With missing
# values encoded as 999999, those rows always fall on the No side.
yes_edges = df[['ID', 'Yes', 'Feature', 'Split']].dropna()
yes_edges['label'] = yes_edges.apply(
    lambda x: "({feature} < {value:.2f})".format(feature=x['Feature'], value=x['Split']),
    axis=1)

no_edges = df[['ID', 'No', 'Feature', 'Split']].dropna()
no_edges['label'] = no_edges.apply(
    lambda x: "({feature} >= {value:.2f} or {feature} = 999999)".format(feature=x['Feature'], value=x['Split']),
    axis=1)

for v in yes_edges.values:
    G.add_edge(v[0], v[1], feature=v[2], expr=v[4])
for v in no_edges.values:
    G.add_edge(v[0], v[1], feature=v[2], expr=v[4])

# Leaf rows carry the leaf value in the Gain column.
leaf_node_values = {i[0]: i[1] for i in df[df.Feature == 'Leaf'][['ID', 'Gain']].values}

roots = []
leaves = []
for node in G.nodes:
    if G.in_degree(node) == 0:    # it's a root
        roots.append(node)
    elif G.out_degree(node) == 0: # it's a leaf
        leaves.append(node)

# Every root-to-leaf path is one decision rule.
paths = []
for root in roots:
    for leaf in leaves:
        for path in nx.all_simple_paths(G, root, leaf):
            paths.append(path)

pred_conditions = []
for path in paths:
    parts = []
    for i in range(len(path) - 1):
        parts.append(G[path[i]][path[i + 1]]['expr'])
    pred_conditions.append("if " + " and ".join(parts)
                           + " then {value:.4f}".format(value=leaf_node_values.get(path[-1])))

for condition in pred_conditions:
    print(condition)
The above code prints every rule in a format like the following:
if (x < a) and (y >= b or y = 999999) then v
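Since the stated goal was to bin variables on these splits, the same dataframe can also yield bin edges directly. Below is a minimal sketch, assuming the df from trees_to_dataframe() above; the 'some_feature' name in the comment is illustrative, not part of xgboost's API:
import numpy as np

# Collect the sorted, de-duplicated split thresholds for each feature
# (leaf rows have Feature == 'Leaf' and no split, so they are excluded).
split_points = (
    df[df.Feature != 'Leaf']
    .groupby('Feature')['Split']
    .apply(lambda s: np.sort(s.unique()))
    .to_dict()
)

# A column can then be binned on its own split points, e.g.:
# bins = np.digitize(data['some_feature'], split_points['some_feature'])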
Upvotes: 2
Reputation: 452
You can find the decision rules as a dataframe through the function model._Booster.trees_to_dataframe(). The Yes column contains the ID of the yes-branch, and the No column that of the no-branch. This way you can reconstruct the tree, since for each row of the dataframe, the node ID has directed edges to Yes and No. You can do that with networkx like so:
import networkx as nx

df = model._Booster.trees_to_dataframe()

# Create a directed graph (edges point from a node to its children)
G = nx.DiGraph()

# Add all the nodes
G.add_nodes_from(df.ID.tolist())

# Add the edges. This should be simpler in Pandas, but there seems to be
# a bug with df.apply(tuple, axis=1) at the moment.
yes_pairs = df[['ID', 'Yes']].dropna()
no_pairs = df[['ID', 'No']].dropna()
yes_edges = [(i[0], i[1]) for i in yes_pairs.values]
no_edges = [(i[0], i[1]) for i in no_pairs.values]
G.add_edges_from(yes_edges + no_edges)
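From here, one possible way to turn the reconstructed graph into rules, mirroring the approach in the earlier answer, is to enumerate root-to-leaf paths. A sketch, assuming the G built above:
# Roots have no incoming edges; leaves have no outgoing edges.
roots = [n for n in G.nodes if G.in_degree(n) == 0]
leaves = [n for n in G.nodes if G.out_degree(n) == 0]

# Every root-to-leaf path corresponds to one decision rule of a tree.
for root in roots:
    for leaf in leaves:
        for path in nx.all_simple_paths(G, root, leaf):
            print(path)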
Upvotes: 8
Reputation: 11444
It is possible, but not easy. I would recommend using GradientBoostingClassifier from scikit-learn, which is similar to xgboost but has native access to the built trees.
With xgboost, however, it is possible to get a textual representation of the model and then parse it:
from sklearn.datasets import load_iris
from xgboost import XGBClassifier
# build a very simple model
X, y = load_iris(return_X_y=True)
model = XGBClassifier(max_depth=2, n_estimators=2)
model.fit(X, y)
# dump it to a text file
model.get_booster().dump_model('xgb_model.txt', with_stats=True)
# read the contents of the file
with open('xgb_model.txt', 'r') as f:
txt_model = f.read()
print(txt_model)
It will print a textual description of 6 trees (2 estimators, each consisting of 3 trees, one per class), which starts like this:
booster[0]:
0:[f2<2.45] yes=1,no=2,missing=1,gain=72.2968,cover=66.6667
1:leaf=0.143541,cover=22.2222
2:leaf=-0.0733496,cover=44.4444
booster[1]:
0:[f2<2.45] yes=1,no=2,missing=1,gain=18.0742,cover=66.6667
1:leaf=-0.0717703,cover=22.2222
2:[f3<1.75] yes=3,no=4,missing=3,gain=41.9078,cover=44.4444
3:leaf=0.124,cover=24
4:leaf=-0.0668394,cover=20.4444
...
Now you can, for example, extract all splits from this description:
import re
# extract all patterns like "[f2<2.45]" (note the raw string and the escaped dot)
splits = re.findall(r'\[f([0-9]+)<([0-9]+\.[0-9]+)\]', txt_model)
splits
It will print a list of (feature_id, split_value) tuples, like:
[('2', '2.45'),
('2', '2.45'),
('3', '1.75'),
('3', '1.65'),
('2', '4.95'),
('2', '2.45'),
('2', '2.45'),
('3', '1.75'),
('3', '1.65'),
('2', '4.95')]
You can further process this list as you wish.
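For the original goal of binning, for example, these splits can serve as bin edges. A minimal sketch, assuming the X and splits defined above (the numpy usage is illustrative, not part of xgboost):
import numpy as np

# Group the split values by feature id and de-duplicate them,
# so each feature gets a sorted list of candidate bin edges.
edges = {}
for fid, value in splits:
    edges.setdefault(int(fid), set()).add(float(value))
edges = {fid: sorted(vals) for fid, vals in edges.items()}

# Bin feature 2 of the iris data on its split points: values below
# the first split fall in bin 0, and so on.
binned = np.digitize(X[:, 2], edges[2])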
Upvotes: 20
Reputation: 83
You need to know the name of your tree, and after that, you can insert it into your code.
Upvotes: -15