Reputation: 3117
I'm new to scikit-learn and random forest regression and was wondering if there is an easy way to get the predictions from every tree in a random forest in addition to the combined prediction.
Basically, I want what R's randomForest package gives you with the predict.all = TRUE option.
# Import the model we are using
from sklearn.ensemble import RandomForestRegressor
# Instantiate model with 1000 decision trees
rf = RandomForestRegressor(n_estimators=1000, random_state=1337)
# Train the model on training data
rf.fit(train_features, train_labels)
# Use the forest's predict method on the test data
predictions = rf.predict(test_features)
print(len(predictions))  # 6565, the number of observations in my test set
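The snippet above depends on the asker's own train_features and test_features. A self-contained stand-in, assuming synthetic data from make_regression in place of the real dataset (the sample counts and the smaller n_estimators are illustrative choices, not from the original), shows the same behaviour: predict returns one combined value per test row, and the fitted trees are kept in rf.estimators_.

```python
# Minimal reproducible stand-in for the question's setup.
# Assumption: synthetic data replaces the asker's real features and labels.
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=500, n_features=8, noise=0.1, random_state=1337)
train_features, test_features, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=1337
)

# Fewer trees than the question's 1000, only to keep the example fast
rf = RandomForestRegressor(n_estimators=50, random_state=1337)
rf.fit(train_features, train_labels)

predictions = rf.predict(test_features)
print(len(predictions))     # 100: one combined prediction per test row
print(len(rf.estimators_))  # 50: one fitted tree per estimator
```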
I want every individual tree's prediction for each observation, not only their mean.
Is it possible to do this in Python?
Upvotes: 4
Views: 2215
Reputation: 3353
Use
import numpy as np
# Ask every fitted tree in the forest for its own predictions
predictions_all = np.array([tree.predict(test_features) for tree in rf.estimators_])
print(predictions_all.shape)  # (1000, 6565): one row per tree, one column per test observation
This uses the estimators_ attribute (see the docs), which is a list of all the fitted DecisionTreeRegressor instances. Calling predict on each of them and wrapping the results in np.array gives one row of predictions per tree.
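A self-contained check of this approach, assuming synthetic data from make_regression in place of the asker's dataset and a 50-tree forest for speed. It stacks the per-tree predictions, confirms that their mean matches rf.predict (for RandomForestRegressor the combined prediction is the mean over the trees), and transposes the array to the (n_observations, n_trees) layout that R's predict.all = TRUE uses for its "individual" matrix. The names X_test and individual_predictions are illustrative, not from the original.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Assumption: synthetic data stands in for the real train/test features.
X, y = make_regression(n_samples=300, n_features=5, noise=0.5, random_state=0)
X_train, X_test, y_train = X[:200], X[200:], y[:200]

rf = RandomForestRegressor(n_estimators=50, random_state=0)
rf.fit(X_train, y_train)

# One row per tree, one column per test observation
predictions_all = np.array([tree.predict(X_test) for tree in rf.estimators_])
print(predictions_all.shape)  # (50, 100)

# For regression, the forest's prediction is the mean over its trees
combined = rf.predict(X_test)
print(np.allclose(predictions_all.mean(axis=0), combined))  # True

# R-style layout: one column per tree, like predict.all's "individual" matrix
individual_predictions = predictions_all.T
print(individual_predictions.shape)  # (100, 50)
```

If test_features is a pandas DataFrame, recent scikit-learn versions may warn about feature names when each tree's predict is called directly, because the individual trees are fitted on the converted array rather than the DataFrame; passing np.asarray(test_features) should avoid that warning.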
Upvotes: 8