Sankalp

Reputation: 75

Getting the rows (tuples) that follow a particular rule of a decision tree in R

I am creating a decision tree using rpart in R. I can also print out the rules generated by the decision tree using the path.rpart() function.

For the airquality data, I get the following rules:

$`8`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R< 79.5"

$`18`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R>=79.5"
[5] "Temp< 77.5"   

$`19`
[1] "root"          "Temp< 82.5"    "Wind>=7.15"    "Solar.R>=79.5"
[5] "Temp>=77.5"   

$`5`
[1] "root"       "Temp< 82.5" "Wind< 7.15"

and so on.

Is there a way to write code that applies these constraints to my original airquality table and returns the rows that follow each rule? This would be equivalent to

airquality[which(airquality$Temp < 82.5 & airquality$Wind >= 7.15 & airquality$Solar.R < 79.5), ]

for the first rule.

Any help is greatly appreciated. Thanks in advance.
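A minimal sketch of the translation being asked for, assuming the tree was fitted as Ozone ~ . on airquality with default settings and that every split is on a numeric variable. Each path.rpart() rule is pasted into a single R expression and evaluated with the data frame's columns in scope. rows_for_rule() is a hypothetical helper, not part of rpart. Two caveats: the split values in the labels are rounded for printing, and rows with a missing split variable are dropped here, whereas rpart itself routes them through surrogate splits. Results can therefore differ from the tree's own assignment for such rows.

```r
library(rpart)

# Assumption: the tree was fitted as Ozone ~ . on airquality (default settings)
fit <- rpart(Ozone ~ ., data = airquality)

# path.rpart() returns (invisibly) a list of character vectors named by node number
rules <- path.rpart(fit, nodes = c(8, 18, 19, 5), print.it = FALSE)

# Hypothetical helper: apply one rule, e.g. c("root", "Temp< 82.5", "Wind>=7.15"),
# to a data frame. Only numeric splits parse; factor split labels such as
# "x=a,b" are not valid R expressions.
rows_for_rule <- function(rule, df) {
  conds <- rule[rule != "root"]
  if (length(conds) == 0L) return(df)  # the root node keeps every row
  # join the conditions with & into a single expression
  expr <- parse(text = paste(conds, collapse = " & "))[[1L]]
  # evaluate the expression with the data frame's columns in scope
  keep <- eval(expr, envir = df)
  # which() drops NA results, i.e. rows with a missing split variable
  df[which(keep), , drop = FALSE]
}

rows_for_rule(rules[["8"]], airquality)
```

rows_for_rule(rules[["8"]], airquality) applies the first rule to airquality in the same way as the which() filter above, without typing the conditions by hand.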

Upvotes: 1

Views: 515

Answers (2)

phiver

Reputation: 23608

path.rpart gives a nice overview, but MrFlick has already written code that shows which observations fall into a specific node (see the first link in the code below).

That function only works on the data the tree was fitted on, because it relies on tree$where. To find the nodes that new observations fall into when predicting, see the second link in the code below.

The example code below uses the function from the first answer; the last part comes from the second answer.

library(rpart)

# split kyphosis into 2 for example
train <- kyphosis[1:60, ]
test <- kyphosis[-(1:60), ]
fit <- rpart(Kyphosis ~ Age + Number + Start, data = train)

# Show nodes
print(fit)

# function to show observations that fall in a node
# https://stackoverflow.com/questions/23924051/find-the-data-elements-in-a-data-frame-that-pass-the-rule-for-a-node-in-a-tree-m
subset_rpart <- function (tree, df, nodes) {
      if (!inherits(tree, "rpart")) 
            stop("Not a legitimate \"rpart\" object")
      stopifnot(nrow(df)==length(tree$where))
      frame <- tree$frame
      n <- row.names(frame)
      node <- as.numeric(n)

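      if (missing(nodes)) {
            # interactive mode: click a node on the plotted tree
            xy <- rpart:::rpartco(tree)
            i <- identify(xy, n = 1L, plot = FALSE)
            # identify() returns integer(0) if nothing was selected
            if (length(i) > 0L) {
                  return( df[tree$where==i, ] )
            } else {
                  return(df[0,])
            }
      }
      else {
            # node.match() turns node numbers into row indices of tree$frame,
            # which is what tree$where stores (only leaves appear in tree$where)
            if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
                  return(df[0,])
            return ( df[tree$where %in% as.numeric(nodes), ] )
      }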
}

subset_rpart(fit, train, 7)


# Find the nodes in which the test observations fall 
# https://stackoverflow.com/questions/13690201/how-to-count-the-observations-falling-in-each-node-of-a-tree?lq=1
# overwrite each node's yval (the value predict() returns) with its node number
nodes_fit <- fit
nodes_fit$frame$yval <- as.numeric(rownames(nodes_fit$frame))
testnodes <- predict(nodes_fit, test, type="vector")
print(testnodes)

Upvotes: 3

Sidhha

Reputation: 268

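library(rpart)
rules <- rpart(Ozone ~ ., data = airquality)
table(rules$where)
# rpart drops rows with a missing response (Ozone), so rules$where only covers these rows
used <- airquality[!is.na(airquality$Ozone), ]
used[rules$where == 6, ]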

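This gives you the rows in one leaf of the tree without coding the rules by hand. Note that where holds the row number of the leaf in rules$frame, not the node number that path.rpart prints; row.names(rules$frame) maps one to the other. I am not sure if that is what you are looking for.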
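A minimal sketch that builds on the where approach, assuming the tree is fitted as Ozone ~ . on airquality with default settings. It translates fit$where into node numbers with row.names(fit$frame), then splits the rows once so that every leaf can be looked up by the same number path.rpart() prints. leaf_node and by_leaf are illustrative names, and node 8 is assumed to exist, as in the question's output. Because this reads rpart's own assignment, rows with a missing Solar.R land in the leaf chosen by rpart's surrogate splits, which evaluating a printed rule cannot reproduce.

```r
library(rpart)

# Assumption: same model as in the question (Ozone ~ . on airquality)
fit <- rpart(Ozone ~ ., data = airquality)

# rpart (via na.rpart) drops rows whose response is missing,
# so fit$where lines up with these rows, not with all of airquality
used <- airquality[!is.na(airquality$Ozone), ]

# fit$where holds row numbers of fit$frame; translate them into the
# node numbers that print(fit) and path.rpart() use
leaf_node <- as.integer(row.names(fit$frame))[fit$where]

# one data frame per leaf, named by node number
by_leaf <- split(used, leaf_node)

names(by_leaf)   # leaf node numbers
by_leaf[["8"]]   # rows in node 8, the first rule in the question
```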

Upvotes: 2
