Reputation: 11
I want to implement the Predictive Prescription approach from Bertsimas et al. (2020), which combines machine learning with optimization. For that, I need to look at the terminal nodes (disjoint regions) of each decision tree in the forest.
Specifically, I want to know the following things for each tree:
I hope my question becomes clearer with the following picture of one decision tree:
Here, for the first terminal node, I am not interested in the prediction m but rather in the values y1, y4 and y5 that form the basis for the prediction.
The perfect result would be a matrix-like structure, where each column represents one tree and each row represents one training (test) sample. For each sample and tree, the structure should give me the ID of the region/terminal node where the sample can be found!
I looked at the randomForest as well as the ranger package but had no luck finding anything relevant. One paper mentioned implementing the method with the caret package, but it didn't say anything about how to bypass the prediction.
Here's a reproducible regression example using ranger:
library(MASS)
library(e1071)
library(ranger)
#load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))
train <- Boston[ind == 1,]
test <- Boston[ind == 2,]
#train random forest
boston.rf <- ranger(medv ~ ., data = train)
Any help is highly appreciated. Cheers!
Upvotes: 1
Views: 397
Reputation: 817
One way I've found so far to get this information is to use the randomForest package with the option keep.inbag=TRUE, which lets you retrieve which samples were used to build each tree, together with the getTree method, which returns the structure of each tree in the forest.
I created a function that retrieves the terminal node ID for a sample, given the tree structure from getTree.
# retrieve the terminal node id for one sample, given a tree structure
# from getTree (numerical features only)
get_terminal_node_id_for_sample <- function(tree, sample) {
  node_id <- 1
  # status is -1 for terminal nodes; keep descending until one is reached
  while (tree$status[node_id] != -1) {
    split_var <- as.character(tree$split.var[node_id])
    # randomForest sends values <= the split point to the left daughter
    if (sample[[split_var]] <= tree$split.point[node_id]) {
      node_id <- as.numeric(tree$left.daughter[node_id])
    } else {
      node_id <- as.numeric(tree$right.daughter[node_id])
    }
  }
  return(node_id)
}
And used it like this:
library(randomForest)
library(MASS)
library(e1071)
# load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))
train <- Boston[ind == 1,]
test <- Boston[ind == 2,]
# train random forest and keep inbag information
model <- randomForest(medv ~ ., data = train, keep.inbag = TRUE)
# get the first tree of the forest
treeind <- 1
tree <- data.frame(getTree(model, k=treeind, labelVar=TRUE))
# loop over each in-bag sample of the first tree
for (sampleind in which(model$inbag[, treeind] > 0)) {
  sample <- train[sampleind, ]
  node_id <- get_terminal_node_id_for_sample(tree, sample)
  ##########################
  # do whatever with node_id
  ##########################
  print(paste("sample", sampleind, "is in terminal node", node_id, sep = " "))
}
Note: I've only tested this with numerical features.
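A possibly simpler route to the matrix the question asks for (one row per sample, one column per tree) is to let the package do the tree traversal: as far as I can tell, ranger's predict accepts type = "terminalNodes" and returns node IDs instead of predictions. The following is only a minimal sketch, assuming the same Boston split as in the question; num.trees = 50 and the names nodes_train, nodes_test, and leaf_values are illustrative choices, not part of the original code.

```r
library(MASS)
library(ranger)

# same data split as in the question
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob = c(0.8, 0.2))
train <- Boston[ind == 1, ]
test <- Boston[ind == 2, ]

# small forest to keep the example quick (num.trees is an arbitrary choice)
boston.rf <- ranger(medv ~ ., data = train, num.trees = 50)

# type = "terminalNodes" returns terminal node IDs instead of predictions:
# one row per sample, one column per tree
nodes_train <- predict(boston.rf, data = train, type = "terminalNodes")$predictions
nodes_test <- predict(boston.rf, data = test, type = "terminalNodes")$predictions

# training responses that share a terminal node with the first test sample
# in the first tree (uses all training rows, not only the in-bag ones)
leaf_values <- train$medv[nodes_train[, 1] == nodes_test[1, 1]]
```

Here leaf_values plays the role of y1, y4 and y5 in the question. With randomForest, predict(model, newdata, nodes = TRUE) should give a comparable matrix in the "nodes" attribute of the result, though I have not compared the two outputs side by side.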
Upvotes: 1