Reputation: 11
Hi, I'm currently trying to extract some of the parent-node information stored in a party object, such as the node ID. For now, I can get the IDs of the terminal nodes using:
library("rpart")
library("partykit")
fit <- rpart(CommuteDistance ~ ., data = Clients)
pr <- as.party(fit)
nodeids(pr, terminal = TRUE)
But how can I get the parent IDs? And, if possible, how can I get the names of the nodes?
Upvotes: 1
Views: 402
Reputation: 11
Since rpart numbers its nodes in a fixed pattern (the children of node n are nodes 2n and 2n + 1), you can determine the parent ID simply as parent_id = floor(node_id/2).
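The rule can be checked in base R without fitting a tree. A minimal sketch; parent_of() is a hypothetical helper, and the IDs below are an illustrative set of rpart-style node numbers, not taken from a fitted model:

```r
# Hypothetical helper: parent of an rpart node ID,
# relying on the children of node n being 2n and 2n + 1.
parent_of <- function(node_id) {
  ifelse(node_id == 1, NA, node_id %/% 2)  # the root (node 1) has no parent
}
ids <- c(1, 2, 3, 6, 7, 12, 13)  # illustrative rpart-style node IDs
parent_of(ids)                   # returns NA 1 1 3 3 6 6
```

Integer division (%/%) gives the same result as floor(node_id/2) for positive IDs; the explicit NA for the root avoids reporting a non-existent node 0.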
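Here is a minimal working example that builds a table mapping node IDs to their parent IDs. It uses tibble's rownames_to_column() (part of the tidyverse) to take the node IDs from the row names of fit$frame rather than from partykit. Note that the floor(node_id/2) rule only holds for rpart's own numbering: as.party() renumbers the nodes consecutively (1, 2, 3, ...), so the rule does not apply to the IDs returned by nodeids().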
library("rpart")
library("tidyverse")
fit <- rpart(Petal.Length ~ ., data = iris)
# Return the rpart frame with explicit node and parent IDs.
get_frame_with_parent <- function(x) {
  x$frame %>%
    tibble::rownames_to_column(var = "node_id") %>%
    mutate(node_id = as.numeric(node_id),
           # children of node n are 2n and 2n + 1, so the parent is floor(n / 2)
           parent_id = floor(node_id / 2))
}
get_frame_with_parent(fit)
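The node names (the split condition that leads into each node, with "root" for node 1) can be obtained with labels(fit), which returns one label per row of fit$frame.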
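To see how the labels line up with the node IDs, they can be placed next to the row names of fit$frame. A minimal sketch, assuming the same iris tree; node_names is just an illustrative variable name:

```r
library("rpart")
fit <- rpart(Petal.Length ~ ., data = iris)
# labels() returns one split label per row of fit$frame, in the same order,
# so the two vectors can simply be combined column-wise.
node_names <- data.frame(
  node_id = as.numeric(rownames(fit$frame)),
  node_label = labels(fit)
)
node_names
```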
Here is a minimal working example that combines the two into a table of node IDs, node labels, parent IDs and parent labels:
library("rpart")
library("tidyverse")
fit <- rpart(Petal.Length ~ ., data = iris)
# Return the rpart frame with node/parent IDs and node/parent labels.
get_frame_with_parent <- function(x) {
  frame_with_parent <-
    x$frame %>%
    tibble::rownames_to_column(var = "node_id") %>%
    mutate(node_id = as.numeric(node_id),
           node_label = labels(x),          # one label per row of x$frame
           parent_id = floor(node_id / 2))  # the root gets 0, i.e. no parent
  # Look up each parent's label by joining the table onto itself.
  frame_with_parent %>%
    left_join(
      dplyr::select(frame_with_parent, node_id, node_label),
      by = c("parent_id" = "node_id"),
      suffix = c("", ".y")
    ) %>%
    dplyr::rename(parent_label = node_label.y)
}
get_frame_with_parent(fit)
Upvotes: 1
Reputation: 17193
There is no readily available function to extract this conveniently. But it is not hard to traverse the recursive partynode structure and extract the custom quantities you are interested in. It also helps to convert the recursive partynode to a flat list first.
For a reproducible example, consider the following rpart tree and its party representation:
library("rpart")
fit <- rpart(Petal.Length ~ ., data = iris)
library("partykit")
pr <- as.party(fit)
Afterwards you can easily convert the tree with as.list(pr$node), which returns all the information from the recursive partynode structure. In particular, this contains the $id of each node and the $kids IDs (if any). Thus we can easily extract these with sapply() and a custom function:
sapply(as.list(pr$node), function(n) {
if(is.null(n$kids)) c(n$id, NA, NA) else c(n$id, n$kids)
})
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
## [1,] 1 2 3 4 5 6 7 8 9
## [2,] 2 NA 4 5 NA NA 8 NA NA
## [3,] 3 NA 7 6 NA NA 9 NA NA
The first column shows that Node 1 has two kids, Nodes 2 and 3. Node 2 is a terminal node because it has no kids (second column), while Node 3 again has two kids, Nodes 4 and 7, and so on.
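Inverting this kids information gives the parent IDs that the question asks for. A minimal sketch, assuming the same iris tree; nodes, ids and parent_id are illustrative variable names, not partykit functions:

```r
library("rpart")
library("partykit")
fit <- rpart(Petal.Length ~ ., data = iris)
pr <- as.party(fit)
# Flatten the recursive partynode into a list with one element per node.
nodes <- as.list(pr$node)
ids <- sapply(nodes, function(n) n$id)
# Every node starts without a parent; each inner node then claims its kids.
parent_id <- rep(NA_integer_, length(ids))
for (n in nodes) {
  if (!is.null(n$kids)) parent_id[match(n$kids, ids)] <- n$id
}
data.frame(node_id = ids, parent_id = parent_id)
```

Because match() looks up positions by ID, the sketch does not rely on the flattened list being ordered by node ID; only the root is left with NA.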
Upvotes: 0