rmania

Reputation: 113

How to find the path from the root to a leaf of a tree returned by getTree in randomForest in R?

The getTree function in the randomForest package in R displays the structure of a particular tree used in the random forest.

Here is an example using the iris dataset:

library(randomForest)
data(iris)
rf <- randomForest(Species ~ ., iris)
getTree(rf, 1)

This shows the output of tree #1 of 500:

   left daughter right daughter split var split point status prediction
1              2              3         3        2.50      1          0
2              0              0         0        0.00     -1          1
3              4              5         4        1.65      1          0
4              6              7         4        1.35      1          0
5              8              9         3        4.85      1          0
6              0              0         0        0.00     -1          2
7             10             11         2        3.10      1          0
8             12             13         4        1.55      1          0
9              0              0         0        0.00     -1          3
10             0              0         0        0.00     -1          3
11             0              0         0        0.00     -1          2
12            14             15         2        2.55      1          0
13             0              0         0        0.00     -1          2
14            16             17         2        2.35      1          0
15             0              0         0        0.00     -1          3
16             0              0         0        0.00     -1          3
17             0              0         0        0.00     -1          2

Now my main aim is to find the path from node 1 (the root) to a terminal node (in this case 2, 6, 9, 10, etc.).

Is there a common algorithm or code I can use?

The path for 9 would be 1 -> 3 -> 5 -> 9, and the path for 10 would be 1 -> 3 -> 4 -> 7 -> 10.

Any help will be much appreciated.
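A minimal top-down sketch, assuming the tree is the matrix printed above, with only the three columns it needs copied by hand so it runs without refitting the forest: start at node 1, recurse into both daughters, and record the path each time a node with status -1 (terminal) is reached. `all_leaf_paths` is a hypothetical helper name.

```r
# Hand-copied from the getTree(rf, 1) output above (only the columns used here)
tree <- cbind(
  "left daughter"  = c(2, 0, 4, 6, 8, 0, 10, 12, 0, 0, 0, 14, 0, 16, 0, 0, 0),
  "right daughter" = c(3, 0, 5, 7, 9, 0, 11, 13, 0, 0, 0, 15, 0, 17, 0, 0, 0),
  "status"         = c(1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, 1, -1, -1, -1)
)

# Hypothetical helper: depth-first walk from `node`, returning a named list
# with one root-to-leaf path per terminal node (status == -1)
all_leaf_paths <- function(tree, node = 1, path = numeric(0)) {
  path <- c(path, node)
  if (tree[node, "status"] == -1) {
    return(setNames(list(path), node))  # name each path by its leaf number
  }
  # unname() drops the column name that single-element indexing can keep
  left  <- unname(tree[node, "left daughter"])
  right <- unname(tree[node, "right daughter"])
  c(all_leaf_paths(tree, left, path), all_leaf_paths(tree, right, path))
}

paths <- all_leaf_paths(tree)
paths[["9"]]   # 1 3 5 9
paths[["10"]]  # 1 3 4 7 10
```

With a fitted forest, `tree <- getTree(rf, 1)` should give a matrix with the same column names, so the helper can be called on it directly.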

Upvotes: 2

Views: 1659

Answers (2)

Soren Havelund Welling

Reputation: 1893

This answer only extends the answer from user 'Enprofylline' to finding the lowest common ancestor of two nodes.

library(randomForest)
data(iris)
set.seed(123) # for reproducibility
rf <- randomForest(Species ~ ., iris)

some_tree <- getTree(rf, 1)

some_tree

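# Returns the node numbers on the path from the root (node 1) to `child`
get_path_to_node_v2 <- function(tree, child){
  if (child == 1) return(1)  # the root's path is just itself
  # the parent is the row whose left or right daughter is `child`
  parent <- which(tree[,'left daughter']==child | tree[,'right daughter']==child)
  return(c(get_path_to_node_v2(tree, child=parent), child))
}
# Deepest node shared by both paths; max() works because every
# daughter has a larger row number than its parent
lowest_ancest <- function(tree, child_A, child_B) {
  path_A <- get_path_to_node_v2(tree, child_A)
  path_B <- get_path_to_node_v2(tree, child_B)
  max(path_A[path_A %in% path_B])
}
lowest_ancest(some_tree, 11, 15)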
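The `max()` step in `lowest_ancest` relies on the numbering seen in the question's output, where every daughter has a larger row number than its parent. A variant that does not depend on that numbering compares the two root-to-node paths position by position and keeps the last node where they still agree. This is a sketch on the question's tree copied by hand; `path_to_node` and `lca_by_prefix` are hypothetical names.

```r
# Hand-copied from the question's getTree(rf, 1) output (only the columns used)
tree <- cbind(
  "left daughter"  = c(2, 0, 4, 6, 8, 0, 10, 12, 0, 0, 0, 14, 0, 16, 0, 0, 0),
  "right daughter" = c(3, 0, 5, 7, 9, 0, 11, 13, 0, 0, 0, 15, 0, 17, 0, 0, 0),
  "status"         = c(1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, 1, -1, -1, -1)
)

# Hypothetical helper: root-to-node path, walking parent links upward
path_to_node <- function(tree, child) {
  if (child == 1) return(1)  # node 1 is the root
  parent <- which(tree[, "left daughter"] == child |
                  tree[, "right daughter"] == child)
  c(path_to_node(tree, parent), child)
}

# Hypothetical helper: last node the two paths share, reading from the root down
lca_by_prefix <- function(tree, a, b) {
  path_a <- path_to_node(tree, a)
  path_b <- path_to_node(tree, b)
  n <- min(length(path_a), length(path_b))
  same <- path_a[seq_len(n)] == path_b[seq_len(n)]
  # both paths start at the root, so at least the first position matches
  path_a[max(which(same))]
}

lca_by_prefix(tree, 11, 15)  # 3
lca_by_prefix(tree, 16, 13)  # 8
```

Once two paths diverge they can never meet again (each node has a single parent), so the matching positions always form a prefix and `max(which(same))` is its last element.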

Upvotes: 0

C8H10N4O2

Reputation: 19005

You can accomplish this via recursion -- something like:

library(randomForest)
data(iris)
set.seed(123) # for reproducibility
rf <- randomForest(Species ~ ., iris)

some_tree <- getTree(rf, 1)

some_tree

get_path_to_node <- function(tree, child){
  if (child == 1) return("1")  # node 1 is the root
  # the parent is the row whose left or right daughter is `child`
  parent <- which(tree[,'left daughter']==child | tree[,'right daughter']==child)
  return(paste(get_path_to_node(tree, child=parent), child, sep='->'))
}

get_path_to_node(some_tree, 5)

gives you 1->3->5

Explanation: We start with a node j. We find its "parent" by looking for the row whose left daughter or right daughter equals j. We then repeat the process for that parent, and so on, until we reach node 1, which by definition is the root. We use paste with sep='->' to build the chain as we go.
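The same upward walk can also be written as a loop. This sketch builds a parent lookup vector once instead of rescanning the matrix with `which()` at every step, which helps if you need paths for many nodes. It assumes the question's tree copied by hand; `parent_lookup` and `path_string` are hypothetical names, and non-terminal nodes are taken to be those whose status is not -1.

```r
# Hand-copied from the question's getTree(rf, 1) output (only the columns used)
tree <- cbind(
  "left daughter"  = c(2, 0, 4, 6, 8, 0, 10, 12, 0, 0, 0, 14, 0, 16, 0, 0, 0),
  "right daughter" = c(3, 0, 5, 7, 9, 0, 11, 13, 0, 0, 0, 15, 0, 17, 0, 0, 0),
  "status"         = c(1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, 1, -1, -1, -1)
)

# Hypothetical helper: parent[k] is the row whose daughter is k (0 for the root)
parent_lookup <- function(tree) {
  parent <- integer(nrow(tree))
  splits <- which(tree[, "status"] != -1)  # non-terminal nodes
  parent[tree[splits, "left daughter"]]  <- splits
  parent[tree[splits, "right daughter"]] <- splits
  parent
}

# Hypothetical helper: follow parent links from `node` up to node 1
path_string <- function(parent, node) {
  path <- node
  while (node != 1) {
    node <- parent[node]
    path <- c(node, path)  # prepend so the root ends up first
  }
  paste(path, collapse = "->")
}

parent <- parent_lookup(tree)   # built once, reused for every query
path_string(parent, 10)  # "1->3->4->7->10"
path_string(parent, 9)   # "1->3->5->9"
```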

Upvotes: 2
