user3560220

Reputation: 221

Closures in R, calling functions within a function, recursive functions

I am new to R and I am trying out a classification decision tree using ctree() from the party package. All seems to be fine: I get the expected result and a descriptive plot.

Now, if I want to extract the results from the summary of the fit, I have to traverse each node and extract its information. Fortunately, this has already been written by @baydoganm here. I want to extend this code and write the results to a data frame instead of printing them.

Reproducible code:

library(party)
ct <- ctree(Species ~ ., data = iris)

traverse <- function(treenode){
  if(treenode$terminal){
    bas=paste(treenode$nodeID,treenode$prediction)
    print(bas) #here the results are printed
    return(0)
  }

  traverse(treenode$left)
  traverse(treenode$right)
}

traverse(ct@tree) #function call

This works fine and I get the output on the console. Now, when I try to write the results to a data frame, I run into problems.

What I have tried so far: writing to a list using a mutable closure, but I am not sure how to get it working.

l <- list()
count = 0
traverse1 <- function(treenode,l){

  if((treenode$terminal == T)){
    count <<- count + 1
    print(count)
    node = c(treenode$nodeID)
    pred = c(treenode$prediction)
    l[[count]] <- data.frame(node,pred) #write results in the dataframe
  }

  traverse1(treenode$left,l)
  traverse1(treenode$right,l)

}
test <- traverse1(ct@tree,l)# function call

I get only the result of my last call to the function, and the rest are NULL.

Upvotes: 1

Views: 98

Answers (2)

Achim Zeileis

Reputation: 17223

If you use the new, improved ctree() implementation from the partykit package, the fitted tree has all the information you need in its fitted component:

library("partykit")
ct <- ctree(Species ~ ., data = iris)
head(fitted(ct))
##   (fitted) (weights) (response)
## 1        2         1     setosa
## 2        2         1     setosa
## 3        2         1     setosa
## 4        2         1     setosa
## 5        2         1     setosa
## 6        2         1     setosa

So for a classification tree you can easily construct the table of absolute frequencies of the response using xtabs() (or table()). And for a regression tree, tapply() could easily be used to get means, medians, etc.
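As a brief aside on the regression case just mentioned, here is a minimal sketch of the tapply() idea. It does not fit a real tree: the data frame `fit` is a hypothetical stand-in for what fitted(ct) would return, with made-up node IDs and response values, and `node_summary` is just an illustrative name.

```r
# Hypothetical stand-in for fitted(ct) from a regression tree:
# one row per observation, holding the terminal node ID and the response.
fit <- data.frame(node = c(2L, 2L, 3L, 3L, 3L),
                  y    = c(1, 3, 10, 20, 30))
names(fit) <- c("(fitted)", "(response)")

# Summarise the response within each terminal node
node_mean   <- tapply(fit[["(response)"]], fit[["(fitted)"]], mean)
node_median <- tapply(fit[["(response)"]], fit[["(fitted)"]], median)

# Collect the per-node summaries in a data frame
node_summary <- data.frame(node   = as.integer(names(node_mean)),
                           mean   = as.vector(node_mean),
                           median = as.vector(node_median))
print(node_summary)
```

The classification example continues below.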

In this case let's look at absolute and relative frequencies in tabular form:

tab <- xtabs(~ `(fitted)` + `(response)`, data = fitted(ct))
tab
##         (response)
## (fitted) setosa versicolor virginica
##        2     50          0         0
##        5      0         45         1
##        6      0          4         4
##        7      0          1        45
ptab <- prop.table(tab, 1)
ptab
##         (response)
## (fitted)     setosa versicolor  virginica
##        2 1.00000000 0.00000000 0.00000000
##        5 0.00000000 0.97826087 0.02173913
##        6 0.00000000 0.50000000 0.50000000
##        7 0.00000000 0.02173913 0.97826087

An alternative route to obtain the frequency table tab would be: table(predict(ct, type = "node"), iris$Species).

If you want to turn any of these into a data frame, as.data.frame() works just fine (probably plus some relabeling of the variables...):

as.data.frame(ptab)
##    X.fitted. X.response.       Freq
## 1          2      setosa 1.00000000
## 2          5      setosa 0.00000000
## 3          6      setosa 0.00000000
## 4          7      setosa 0.00000000
## 5          2  versicolor 0.00000000
## 6          5  versicolor 0.97826087
## 7          6  versicolor 0.50000000
## 8          7  versicolor 0.02173913
## 9          2   virginica 0.00000000
## 10         5   virginica 0.02173913
## 11         6   virginica 0.50000000
## 12         7   virginica 0.97826087
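The relabeling mentioned above could look roughly like this. To keep the sketch self-contained, `tab` here is a small hypothetical table with the same dimension names as the real one (only two nodes, made-up counts), and the column names `node`, `species` and `proportion` are just suggestions.

```r
# Hypothetical stand-in for the frequency table built with xtabs()
tab <- as.table(matrix(c(50, 0, 0, 45, 0, 1), nrow = 2,
                       dimnames = list("(fitted)"   = c("2", "5"),
                                       "(response)" = c("setosa", "versicolor",
                                                        "virginica"))))

# Row-wise relative frequencies, as in the answer
ptab <- prop.table(tab, 1)

# Convert to long format and give the columns friendlier names
df <- as.data.frame(ptab, stringsAsFactors = FALSE)
names(df) <- c("node", "species", "proportion")
print(df)
```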

Upvotes: 2

YCR

Reputation: 4022

Smart way: use assign() to write in the global environment:

require(party) 
ct <- ctree(Species ~ ., data = iris)

tt <- NULL

traverse <- function(treenode){
  if(treenode$terminal){
    bas=paste(treenode$nodeID,treenode$prediction)
    assign("tt", c(tt, bas), envir = .GlobalEnv)
    print(bas) #here the results are printed
    return(0)
  } 

  traverse(treenode$left)
  traverse(treenode$right)
}

traverse(ct@tree) #function call

library(stringr) # provides str_split(), used below

data.frame(node.id = unlist(lapply(str_split(tt, " "), function(x) x[[1]])),
           prediction = unlist(lapply(str_split(tt, " "), function(x) x[[2]])))
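A variant of the same idea that stays closer to the "mutable closure" the question asks about: keep the accumulator in the environment of a wrapper function rather than in .GlobalEnv. The attempt in the question probably fails because R copies a list when it is modified, so `l[[count]] <- ...` only changes the copy local to that call. The sketch below is an assumption-laden illustration, not the party API: `toy_tree` is a hypothetical nested list that mimics only the fields the traversal reads, its predictions are plain labels rather than what ct@tree stores, and `collect_leaves` is a made-up name.

```r
# Hypothetical stand-in for ct@tree: nested lists with the fields
# terminal, nodeID, prediction, left and right.
leaf <- function(id, pred) list(terminal = TRUE, nodeID = id, prediction = pred)
inner <- function(id, left, right)
  list(terminal = FALSE, nodeID = id, left = left, right = right)

toy_tree <- inner(1, leaf(2, "setosa"),
                  inner(3, leaf(4, "versicolor"), leaf(5, "virginica")))

collect_leaves <- function(tree) {
  rows <- list()  # accumulator, lives in this call's environment

  visit <- function(node) {
    if (node$terminal) {
      # `<<-` updates `rows` in the enclosing environment, not a local copy
      rows[[length(rows) + 1]] <<- data.frame(node = node$nodeID,
                                              pred = node$prediction)
      return(invisible(NULL))
    }
    visit(node$left)
    visit(node$right)
  }

  visit(tree)
  do.call(rbind, rows)  # one row per terminal node
}

leaves <- collect_leaves(toy_tree)
print(leaves)
```

Because `rows` is created fresh on every call to `collect_leaves()`, repeated calls do not pile results onto a leftover global variable.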

Dirty way: use sink() to save your printed output.

sink(file = "test.csv", append = T)
traverse(ct@tree) #function call
sink()

tt <- read.csv("test.csv", header = F)
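A third possibility, sketched here under the same caveats, avoids side effects altogether: let each recursive call return a data frame and combine the pieces with rbind(). Again, `toy_tree` is a hypothetical nested list standing in for ct@tree (only the fields the traversal reads, with plain labels as predictions), and `leaves_df` is an illustrative name.

```r
# Hypothetical stand-in for ct@tree (not the real party node structure)
leaf <- function(id, pred) list(terminal = TRUE, nodeID = id, prediction = pred)
inner <- function(id, left, right)
  list(terminal = FALSE, nodeID = id, left = left, right = right)

toy_tree <- inner(1,
                  inner(2, leaf(3, "a"), leaf(4, "b")),
                  leaf(5, "c"))

leaves_df <- function(node) {
  if (node$terminal) {
    # Base case: a one-row data frame for this terminal node
    return(data.frame(node = node$nodeID, pred = node$prediction))
  }
  # Recursive case: stack the results from both subtrees
  rbind(leaves_df(node$left), leaves_df(node$right))
}

result <- leaves_df(toy_tree)
print(result)
```

Repeated rbind() calls copy data at every level, which is unlikely to matter for trees of the size ctree() typically produces, but the closure-based accumulator may scale better for very large trees.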

Upvotes: 2
