dominikdw

Random Forest Regression: Extracting the training samples in the terminal nodes of each tree

I want to implement the Predictive Prescription approach from Bertsimas et al. (2020), which combines machine learning with optimization. For that, I need to look at the terminal nodes (disjoint regions) of each decision tree in the forest.
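As I understand the paper, the terminal nodes matter because the random-forest weights are built from leaf co-membership. In my own notation (not necessarily the paper's), there are T trees and N training samples (x^i, y^i), R^t(x) is the leaf of tree t that contains x, and c(z; y) is the cost of decision z under outcome y:

```latex
w_i(x) = \frac{1}{T} \sum_{t=1}^{T}
         \frac{\mathbb{1}\{ R^t(x) = R^t(x^i) \}}
              {\lvert \{ j : R^t(x^j) = R^t(x) \} \rvert},
\qquad
\hat{z}(x) \in \arg\min_{z \in \mathcal{Z}} \sum_{i=1}^{N} w_i(x)\, c(z; y^i)
```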

Specifically, I want to know the following things for each tree:

  1. In which region do the training samples fall?
  2. To which region do the test samples belong?

I hope my question becomes clearer with the following picture of one decision tree:

[Image: Regression Tree Example]

Here, for the first terminal node, I am not interested in the prediction m but rather in the values y1, y4 and y5 that form the basis for the prediction.


The perfect result would be a matrix-like structure in which each column represents one tree and each row represents one training (or test) sample. For each sample and tree, the entry should be the ID of the region/terminal node that the sample falls into.
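Once such matrices exist for the training and the test samples, the weights for the prescription step could be computed from them. A minimal sketch, assuming each tree splits its weight equally among the training samples that share the test point's leaf; `rf_weights` and the toy leaf IDs below are hypothetical, not from any package:

```r
# Hypothetical helper: random-forest weights from terminal-node matrices.
# train_nodes: n_train x n_trees matrix of leaf IDs for the training samples
# test_nodes:  n_test  x n_trees matrix of leaf IDs for the test samples
# Returns an n_test x n_train matrix; row k holds the weights for test sample k.
rf_weights <- function(train_nodes, test_nodes) {
  stopifnot(ncol(train_nodes) == ncol(test_nodes))
  n_trees <- ncol(train_nodes)
  w <- matrix(0, nrow = nrow(test_nodes), ncol = nrow(train_nodes))
  for (t in seq_len(n_trees)) {
    # same[k, i] is TRUE when test sample k and training sample i share a leaf in tree t
    same <- outer(test_nodes[, t], train_nodes[, t], "==")
    # training samples in each test sample's leaf; pmax avoids dividing by zero,
    # but a test leaf with no training samples then contributes no weight
    leaf_size <- pmax(rowSums(same), 1)
    # leaf_size is recycled down the columns, so row k is divided by leaf_size[k]
    w <- w + same / leaf_size
  }
  w / n_trees
}

# Toy example with made-up leaf IDs: 4 training samples, 1 test sample, 2 trees
train_nodes <- matrix(c(2, 2, 3, 3,   # tree 1
                        4, 5, 5, 5),  # tree 2
                      ncol = 2)
test_nodes <- matrix(c(2, 5), ncol = 2)

w <- rf_weights(train_nodes, test_nodes)
w
```

In the toy example, tree 1 gives 1/2 to training samples 1 and 2, tree 2 gives 1/3 to samples 2, 3 and 4, so averaging over the two trees yields 1/4, 5/12, 1/6 and 1/6.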

I looked at the randomForest and ranger packages but had no luck finding anything relevant. One paper mentioned implementing the method with the caret package, but it did not explain how to get at the terminal nodes instead of the prediction.

Here's a reproducible regression example using ranger:

library(MASS)
library(e1071)
library(ranger)

#load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))

train <- Boston[ind == 1,]
test <- Boston[ind == 2,]

#train random forest
boston.rf <- ranger(medv ~ ., data = train) 

Any help is highly appreciated. Cheers!


Answers (1)

smaica

One way I've found so far to get this information is to combine two things from the randomForest package: the option keep.inbag = TRUE, which records which training samples were used to grow each tree, and the function getTree, which returns the structure of any tree in the forest.
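To see what these two pieces look like before walking the trees, a quick inspection sketch, using the same data split as in the question (object names such as `tree1` are my own):

```r
library(MASS)          # Boston data
library(randomForest)

data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob = c(0.8, 0.2))
train <- Boston[ind == 1, ]

model <- randomForest(medv ~ ., data = train, keep.inbag = TRUE)

# inbag: one row per training sample, one column per tree;
# a value > 0 means the sample was drawn into that tree's bootstrap sample
dim(model$inbag)
model$inbag[1:5, 1:3]

# getTree: one row per node of tree k; status == -1 marks a terminal node
tree1 <- getTree(model, k = 1, labelVar = TRUE)
head(tree1)
sum(tree1$status == -1)  # number of terminal nodes in tree 1
```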

I wrote a function that walks down a tree returned by getTree and returns the ID of the terminal node a given sample ends up in.

# Return the terminal node ID (row number in the getTree output) for one sample.
# tree:   data.frame(getTree(model, k, labelVar = TRUE))
# sample: one-row data.frame; only numerical features are handled
get_terminal_node_id_for_sample <- function(tree, sample) {
  node_id <- 1  # start at the root
  # status == -1 marks a terminal node
  while (tree$status[node_id] != -1) {
    split_var <- as.character(tree$split.var[node_id])
    # randomForest sends values <= split point to the left daughter
    if (sample[[split_var]] <= tree$split.point[node_id]) {
      node_id <- tree$left.daughter[node_id]
    } else {
      node_id <- tree$right.daughter[node_id]
    }
  }
  node_id
}

And used it like this:

library(randomForest)
library(MASS)

# load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob = c(0.8, 0.2))

train <- Boston[ind == 1, ]
test <- Boston[ind == 2, ]

# train random forest and keep in-bag information
model <- randomForest(medv ~ ., data = train, keep.inbag = TRUE)

# get the first tree of the forest
treeind <- 1
tree <- data.frame(getTree(model, k = treeind, labelVar = TRUE))

# loop over each in-bag sample of the first tree
for (sampleind in which(model$inbag[, treeind] > 0)) {
  sample <- train[sampleind, ]
  node_id <- get_terminal_node_id_for_sample(tree, sample)

  ##########################
  # do whatever with node_id
  ##########################

  print(paste("sample", sampleind, "is in terminal node", node_id))
}

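Note: I've only tested this with numerical features. getTree encodes splits on categorical predictors differently (as an integer whose binary expansion lists the categories), so the function above would need to be extended for those.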
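If the goal is the full samples-by-trees matrix from the question, walking every tree by hand may not be necessary. As far as I know, randomForest's predict method returns the terminal node of every sample in every tree when called with nodes = TRUE, and ranger's predict method offers type = "terminalNodes" for the same purpose. A sketch under those assumptions (variable names are my own):

```r
library(MASS)
library(randomForest)
library(ranger)

# same data split as above
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob = c(0.8, 0.2))
train <- Boston[ind == 1, ]
test  <- Boston[ind == 2, ]

# randomForest: nodes = TRUE stores an n x ntree matrix of terminal node IDs
# in the "nodes" attribute of the prediction (rows = samples, columns = trees)
model <- randomForest(medv ~ ., data = train, keep.inbag = TRUE)
train_nodes <- attr(predict(model, newdata = train, nodes = TRUE), "nodes")
test_nodes  <- attr(predict(model, newdata = test,  nodes = TRUE), "nodes")

# training responses grouped by terminal node of tree 1, restricted to the
# samples that were in-bag for that tree
in_bag <- model$inbag[, 1] > 0
leaves_tree1 <- split(train$medv[in_bag], train_nodes[in_bag, 1])
head(leaves_tree1)

# ranger: type = "terminalNodes" returns the same kind of matrix
# (node IDs follow ranger's own numbering, not randomForest's)
rf <- ranger(medv ~ ., data = train)
ranger_test_nodes <- predict(rf, data = test, type = "terminalNodes")$predictions
```

The split() call uses only in-bag samples because those are the ones the tree was grown on; passing all of train_nodes[, 1] would also include out-of-bag samples that happen to land in the same leaf.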
