Reputation: 75
I am creating a decision tree using rpart in R. I can also print out the rules generated by the decision tree using the path.rpart() function.
For the airquality data, I get the following rules:
$`8`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R< 79.5"
$`18`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
[5] "Temp< 77.5"
$`19`
[1] "root" "Temp< 82.5" "Wind>=7.15" "Solar.R>=79.5"
[5] "Temp>=77.5"
$`5`
[1] "root" "Temp< 82.5" "Wind< 7.15"
and so on.
Is there a way to write code that applies these constraints to my original airquality data frame and returns the rows that satisfy each rule? For the first rule, that would be equivalent to
airquality[which(airquality$Temp < 82.5 & airquality$Wind >= 7.15 & airquality$Solar.R < 79.5), ]
Any help is greatly appreciated. Thanks in advance.
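A minimal sketch of doing this directly from the path.rpart() output, assuming every split is on a numeric variable so that each rule string (such as "Temp< 82.5") parses as valid R. rule_subsets() is a hypothetical helper, not part of rpart: for each node it drops "root", joins the remaining conditions with &, and evaluates them inside the data frame. Two caveats: rows with NA in a split variable are dropped here, whereas rpart routes them with surrogate splits, and the thresholds come from the printed labels, so they may be rounded.

```r
library(rpart)

# Hypothetical helper (not part of rpart): for each requested node, turn its
# path.rpart() rule into a logical filter and return the matching rows of df.
# Assumes numeric splits only, so strings like "Temp< 82.5" are valid R.
rule_subsets <- function(fit, df, nodes) {
  paths <- path.rpart(fit, nodes = nodes, print.it = FALSE)
  lapply(paths, function(p) {
    if (length(p) == 1L) return(df)            # root node: no conditions
    cond <- paste(p[-1], collapse = " & ")     # drop "root", AND the rest
    keep <- eval(parse(text = cond), envir = df)
    df[which(keep), , drop = FALSE]            # which() also drops NA results
  })
}

# Usage on airquality: one data frame per leaf of the fitted tree
fit <- rpart(Ozone ~ ., data = airquality)
leaves <- as.numeric(rownames(fit$frame)[fit$frame$var == "<leaf>"])
leaf_rows <- rule_subsets(fit, airquality, leaves)
sapply(leaf_rows, nrow)
```

The list is named by node number, so leaf_rows[["8"]] would hold the rows for rule 8 if that node exists in your tree.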
Upvotes: 1
Views: 515
Reputation: 23608
path.rpart gives a nice overview, but MrFlick has already written code that shows which observations fall in a specific node (see the first link in the code below).
That only covers the data used to grow the rpart tree. To find which node new observations (for example a test set) fall into, see the second link.
In the example code below, the subset_rpart() function comes from the first answer and the last part comes from the second.
library(rpart)
# split kyphosis into 2 for example
train <- kyphosis[1:60, ]
test <- kyphosis[-(1:60), ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)
# Show nodes
print(fit)
# function to show observations that fall in a node
# https://stackoverflow.com/questions/23924051/find-the-data-elements-in-a-data-frame-that-pass-the-rule-for-a-node-in-a-tree-m
subset_rpart <- function (tree, df, nodes) {
  if (!inherits(tree, "rpart"))
    stop("Not a legitimate \"rpart\" object")
  # df must be exactly the rows used to fit the tree
  stopifnot(nrow(df) == length(tree$where))
  frame <- tree$frame
  n <- row.names(frame)
  node <- as.numeric(n)
  if (missing(nodes)) {
    # no nodes given: click on the plotted tree to pick one
    xy <- rpart:::rpartco(tree)
    i <- identify(xy, n = 1L, plot = FALSE)
    if (i > 0L) {
      return(df[tree$where == i, ])
    } else {
      return(df[0, ])
    }
  } else {
    # map node numbers to rows of tree$frame, which is what tree$where stores
    if (length(nodes <- rpart:::node.match(nodes, node)) == 0L)
      return(df[0, ])
    return(df[tree$where %in% as.numeric(nodes), ])
  }
}
subset_rpart(fit, train, 7)
# Find the nodes in which the test observations fall
# https://stackoverflow.com/questions/13690201/how-to-count-the-observations-falling-in-each-node-of-a-tree?lq=1
nodes_fit <- fit
nodes_fit$frame$yval <- as.numeric(rownames(nodes_fit$frame))
testnodes <- predict(nodes_fit, test, type="vector")
print(testnodes)
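For the non-interactive case, a shorter variant can avoid the internal rpart::: functions. This is a sketch, not MrFlick's code, and rows_in_node() is a hypothetical name. Because tree$where stores row positions in tree$frame rather than node numbers, the node number is first mapped to its frame row with match().

```r
library(rpart)

# Hypothetical helper: rows of df that rpart placed in node `node`.
# df must be the data the tree was fitted on (same rows, same order).
rows_in_node <- function(fit, df, node) {
  stopifnot(nrow(df) == length(fit$where))
  frame_row <- match(node, as.numeric(rownames(fit$frame)))
  if (is.na(frame_row)) stop("node ", node, " is not in this tree")
  df[fit$where == frame_row, , drop = FALSE]
}

# Usage: count training rows in every leaf
train <- kyphosis[1:60, ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)
is_leaf <- fit$frame$var == "<leaf>"
leaf_nodes <- as.numeric(rownames(fit$frame)[is_leaf])
sizes <- sapply(leaf_nodes, function(k) nrow(rows_in_node(fit, train, k)))
sizes
```

Only leaves receive observations in tree$where, so an internal node number returns an empty data frame here.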
Upvotes: 3
Reputation: 268
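rules <- rpart(airquality)  # a data frame is used as the model frame: first column (Ozone) is the response
table(rules$where)
airquality[names(rules$where)[rules$where == 6], ]
will give you the rows in one leaf of the split data frame without coding the rules by hand. Note that rules$where holds row numbers of rules$frame (not node numbers) and only covers the rows rpart kept (rows with missing Ozone are dropped), so index airquality by its names rather than using it directly as a logical filter. I am not sure if that is what you are looking for.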
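To get every leaf at once, a hedged sketch: split the rows rpart actually used by the node number of their leaf. This assumes fit$where is named by the row names of the rows rpart kept (rows with a missing response are dropped), and it converts frame rows to node numbers so the list names match the node labels printed by print(fit) and path.rpart().

```r
library(rpart)

# Rows with missing Ozone (the response) are dropped by rpart.
fit <- rpart(Ozone ~ ., data = airquality)

# fit$where is named by the row names of the rows actually used in the fit.
used <- airquality[names(fit$where), ]

# Map frame rows to node numbers, then build one data frame per leaf.
node_ids <- as.numeric(rownames(fit$frame))[fit$where]
by_leaf <- split(used, node_ids)
sapply(by_leaf, nrow)
```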
Upvotes: 2