Reputation: 221
I am new to R and I am trying out a classification decision tree using ctree() from the party library. All seems to be fine: I get the expected result and a descriptive plot.
Now, if I want to extract the results from the summary of the fit, I have to traverse each node and extract the information. Fortunately, this has already been written by @baydoganm here. I want to extend this code and write the results to a data frame instead of printing them.
Reproducible code:
library(party)
ct <- ctree(Species ~ ., data = iris)
traverse <- function(treenode){
  if(treenode$terminal){
    bas=paste(treenode$nodeID,treenode$prediction)
    print(bas) #here the results are printed
    return(0)
  }
  traverse(treenode$left)
  traverse(treenode$right)
}
traverse(ct@tree) #function call
This works fine and I get the output on the console. Now, when I try to write the results to a data frame, I run into problems.
What I have tried so far: writing to a list using mutable closures, but I am not sure how to get it working.
l <- list()
count = 0
traverse1 <- function(treenode,l){
  if((treenode$terminal == T)){
    count <<- count + 1
    print(count)
    node = c(treenode$nodeID)
    pred = c(treenode$prediction)
    l[[count]] <- data.frame(node,pred) #write results in the dataframe
  }
  traverse1(treenode$left,l)
  traverse1(treenode$right,l)
}
test <- traverse1(ct@tree,l)# function call
I only get the result of my last call to the function; the rest are NULL.
Upvotes: 1
Reputation: 17223
If you use the new, improved ctree() implementation from the partykit package, then all the information you need is in its fitted component:
library("partykit")
ct <- ctree(Species ~ ., data = iris)
head(fitted(ct))
## (fitted) (weights) (response)
## 1 2 1 setosa
## 2 2 1 setosa
## 3 2 1 setosa
## 4 2 1 setosa
## 5 2 1 setosa
## 6 2 1 setosa
So for a classification tree you can easily construct the table of absolute frequencies of the response using xtabs() (or table()). And for a regression tree, tapply() could easily be used to get means, medians, etc.
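For the regression case, a minimal sketch of the tapply() idea might look as follows. It assumes the built-in airquality data as an example; the names airq, rt, fit, node_means and res are purely illustrative and not part of the original question.

library("partykit")
# Assumption: airquality with complete Ozone values serves as example data
airq <- subset(airquality, !is.na(Ozone))
rt <- partykit::ctree(Ozone ~ ., data = airq)
fit <- fitted(rt)
# Mean response within each terminal node, indexed by node id
node_means <- tapply(fit[["(response)"]], fit[["(fitted)"]], mean)
# One row per terminal node
res <- data.frame(node = names(node_means), mean = as.vector(node_means))
res

Replacing mean with median (or any other summary function) gives the corresponding per-node statistic.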
In this case let's look at absolute and relative frequencies in tabular form:
tab <- xtabs(~ `(fitted)` + `(response)`, data = fitted(ct))
tab
## (response)
## (fitted) setosa versicolor virginica
## 2 50 0 0
## 5 0 45 1
## 6 0 4 4
## 7 0 1 45
ptab <- prop.table(tab, 1)
ptab
## (response)
## (fitted) setosa versicolor virginica
## 2 1.00000000 0.00000000 0.00000000
## 5 0.00000000 0.97826087 0.02173913
## 6 0.00000000 0.50000000 0.50000000
## 7 0.00000000 0.02173913 0.97826087
An alternative route to obtain the frequency table tab would be: table(predict(ct, type = "node"), iris$Species).
If you want to turn any of these into a data frame, as.data.frame() works just fine (probably plus some relabeling of the variables...):
as.data.frame(ptab)
## X.fitted. X.response. Freq
## 1 2 setosa 1.00000000
## 2 5 setosa 0.00000000
## 3 6 setosa 0.00000000
## 4 7 setosa 0.00000000
## 5 2 versicolor 0.00000000
## 6 5 versicolor 0.97826087
## 7 6 versicolor 0.50000000
## 8 7 versicolor 0.02173913
## 9 2 virginica 0.00000000
## 10 5 virginica 0.02173913
## 11 6 virginica 0.50000000
## 12 7 virginica 0.97826087
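One possible way to handle the relabeling is to name the dimensions when building the table and to pass responseName to as.data.frame(). This is only a sketch: the labels node, species and proportion are arbitrary choices, and tab2 and df are illustrative names.

library("partykit")
ct <- partykit::ctree(Species ~ ., data = iris)
# Named arguments to table() become the dimension names
tab2 <- table(node = predict(ct, type = "node"), species = iris$Species)
# Row-wise proportions, then long format with a custom value column name
df <- as.data.frame(prop.table(tab2, 1), responseName = "proportion")
head(df)

Within each node, the values in the proportion column sum to 1.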
Upvotes: 2
Reputation: 4022
Smart way: use assign() to write to the global environment:
library(party)
library(stringr) # provides str_split()
ct <- ctree(Species ~ ., data = iris)
tt <- NULL
traverse <- function(treenode){
  if(treenode$terminal){
    bas <- paste(treenode$nodeID, treenode$prediction)
    # append the result to the global vector tt
    assign("tt", c(tt, bas), envir = .GlobalEnv)
    print(bas) # here the results are printed
    return(0)
  }
  traverse(treenode$left)
  traverse(treenode$right)
}
traverse(ct@tree) # function call
# split each "nodeID prediction" string into its two parts
parts <- str_split(tt, " ")
data.frame(node.id = sapply(parts, function(x) x[[1]]),
           prediction = sapply(parts, function(x) x[[2]]))
Dirty way: use sink() to save your printed output:
sink(file = "test.csv", append = TRUE)
# invisible() keeps the function's return value out of the file
invisible(traverse(ct@tree)) # function call
sink()
tt <- read.csv("test.csv", header = FALSE)
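A third option, which avoids both the global variable and the temporary file, is to let the recursion return its results and combine them with rbind(). The following is only a sketch: collect_terminals is a made-up helper name, and it assumes the same node fields used above (terminal, nodeID, prediction, left, right).

library(party)
ct <- party::ctree(Species ~ ., data = iris)
# Hypothetical helper: returns one data frame row per terminal node
collect_terminals <- function(treenode) {
  if (treenode$terminal) {
    return(data.frame(
      node = treenode$nodeID,
      # collapse the prediction vector into a single string
      pred = paste(treenode$prediction, collapse = " ")
    ))
  }
  # inner node: stack the rows from both subtrees
  rbind(collect_terminals(treenode$left),
        collect_terminals(treenode$right))
}
res <- collect_terminals(ct@tree)
res

Because every call returns a value, nothing outside the function is modified and the result can be assigned directly.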
Upvotes: 2