fmherla

Reputation: 23

partykit: Change the terminal node boxplots to violins

The package partykit offers a plotting function for decision trees, plot.constparty(), which can display the response distribution in each terminal node with boxplots (node_boxplot()). A minimal example using the iris dataset is below.

library("partykit")
ct <- ctree(Petal.Length ~ Sepal.Length + Sepal.Width, data = iris, stump = TRUE)
plot(ct, terminal_panel = node_boxplot)

I would love to display the boxplots as violin plots. Since you can write your own panel functions, that should be possible. However, it seems that the violin plot needs to be set up using grid functions, and I have no clue how to do that. I imagine that this is quite cumbersome work, but I believe that many users would benefit from such a panel function. Any suggestions on how to implement it? (A first lead points here: partykit: Change terminal node boxplots to bar graphs that shows mean and standard deviation)

Add-on: Assume we had a strategy to plot terminal nodes with violins. How could we apply this strategy to multivariate responses, so that violins are displayed instead of boxplots? See the following screenshot produced with the function node_mvar():

[Figure: Decision tree with multivariate response: boxplots produced by node_mvar()]

Upvotes: 1

Views: 425

Answers (2)

fmherla


Here is a version of a node_violinplot() panel-generating function:

node_violinplot <- function (obj, col = "black", fill = "lightgray", bg = "white",
                             width = 0.8, yscale = NULL, ylines = 3, cex = 0.5, id = TRUE,
                             mainlab = NULL, gp = gpar(),
                             col.box = "black", fill.box = "black", fill.median = "white")
{
  y <- obj$fitted[["(response)"]]
  stopifnot(is.numeric(y))
  if (is.null(yscale))
    yscale <- range(y) + c(-0.1, 0.1) * diff(range(y))
  rval <- function(node) {
    nid <- id_node(node)
    dat <- data_party(obj, nid)
    yn <- dat[["(response)"]]
    wn <- dat[["(weights)"]]
    if (is.null(wn))
      wn <- rep(1, length(yn))

    ## compute kernel density estimate
    kde <- stats::density(rep.int(yn, wn), from = yscale[1], to = yscale[2], na.rm = TRUE)
    ## limit kde to range(yn)
    idx <- which(kde$x < range(yn)[2] & kde$x > range(yn)[1])
    kde$y <- kde$y[idx]
    kde$x <- kde$x[idx]

    ## construct polygon coordinates
    width.scalingfactor <- width / 2 / max(kde$y, na.rm = TRUE)
    polX <- c((0.5 - (kde$y * width.scalingfactor)), rev(0.5 + (kde$y * width.scalingfactor)))
    polY <- c(kde$x, rev(kde$x))

    ## compute boxplot characteristics
    x <- boxplot(rep.int(yn, wn), plot = FALSE)

    ## top viewport: title row above the plotting row, y-axis margin on the left
    top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                         widths = unit(c(ylines, 1, 1), c("lines", "null", "lines")),
                         heights = unit(c(1, 1), c("lines", "null"))),
                       width = unit(1, "npc"),
                       height = unit(1, "npc") - unit(2, "lines"),
                       name = paste("node_boxplot", nid, sep = ""), gp = gp)
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)
    if (is.null(mainlab)) {
      mainlab <- if (id) {
        function(id, nobs) sprintf("Node %s (n = %s)",
                                   id, nobs)
      }
      else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(names(obj)[nid], sum(wn))
    }
    grid.text(mainlab)
    popViewport()
    plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                     xscale = c(0, 1), yscale = yscale,
                     name = paste0("node_boxplot", nid, "plot"), clip = FALSE)
    pushViewport(plot)
    grid.yaxis()
    grid.rect(gp = gpar(fill = "transparent"))
    grid.clip()
    ## draw violin
    grid.polygon(unit(polX,"npc"), unit(polY, "native"),
                 gp = gpar(col = col, fill = fill))
    ## draw boxplot
    box.width <- max(polX-0.5, na.rm = TRUE) * 0.08
    grid.rect(unit(0.5, "npc"), unit(x$stats[2], "native"),
              width = unit(box.width, "npc"), height = unit(diff(x$stats[c(2, 4)]), "native"),
              just = c("center", "bottom"),
              gp = gpar(col = col.box, fill = fill.box))
    grid.lines(unit(0.5, "npc"), unit(x$stats[1:2], "native"),
               gp = gpar(col = col))
    grid.lines(unit(0.5, "npc"), unit(x$stats[4:5], "native"),
               gp = gpar(col = col))
    grid.points(unit(0.5, "npc"), unit(x$stats[3], "native"),
                size = unit(0.5, "char"),
                gp = gpar(col = fill.median, fill = fill.median), pch = 19)
    upViewport(2)
  }
  return(rval)
}
class(node_violinplot) <- "grapcon_generator"
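
The geometric core of node_violinplot(), i.e. the kernel density estimate mirrored around x = 0.5, can be tried out on its own without partykit or grid. The base R sketch below pulls that step into a helper called violin_coords(); the name is made up for illustration and is not part of partykit. As in the panel function, case weights are expanded with rep.int(), so non-integer weights would be truncated.

```r
## violin_coords() is a hypothetical helper (not a partykit function) that
## reproduces the outline computation of node_violinplot() in isolation.
violin_coords <- function(y, w = rep(1, length(y)),
                          yscale = range(y) + c(-0.1, 0.1) * diff(range(y)),
                          width = 0.8) {
  ## expand (integer) case weights, as the panel function does
  yw <- rep.int(y, w)
  ## kernel density estimate evaluated across the y-axis range
  kde <- stats::density(yw, from = yscale[1], to = yscale[2], na.rm = TRUE)
  ## keep only the part of the estimate inside the observed data range
  keep <- kde$x > min(y) & kde$x < max(y)
  dx <- kde$x[keep]
  dy <- kde$y[keep]
  ## scale the density so that the widest point spans `width` npc units
  half <- dy * width / 2 / max(dy)
  list(x = c(0.5 - half, rev(0.5 + half)),  # left edge, then right edge in reverse
       y = c(dx, rev(dx)))                  # y positions in native (data) units
}

v <- violin_coords(iris$Petal.Length)
range(v$x)  # approximately 0.1 and 0.9 for the default width = 0.8
```

The two returned vectors play the role of polX and polY in the panel function.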

And a version of node_mvar_violin() that plots the terminal violins for a multivariate response:

.nobs_party <- function(party, id = 1L) {
  dat <- data_party(party, id = id)
  if("(weights)" %in% names(dat)) sum(dat[["(weights)"]]) else NROW(dat)
}

#' @export
node_mvar_violin <- function(obj, which = NULL, id = TRUE, pop = TRUE, ylines = NULL, mainlab = NULL, varlab = TRUE, bg = "white", terminal_panel_mvar = node_violinplot, ...)
{
  ## obtain dependent variables
  y <- obj$fitted[["(response)"]]

  ## fitted node ids
  fitted <- obj$fitted[["(fitted)"]]

  ## number of panels needed
  if(is.null(which)) which <- 1L:NCOL(y)
  k <- length(which)

  rval <- function(node) {

    tid <- id_node(node)
    nobs <- .nobs_party(obj, id = tid)

    ## set up top viewport
    top_vp <- viewport(layout = grid.layout(nrow = k, ncol = 2,
                                            widths = unit(c(ylines, 1), c("lines", "null")), heights = unit(k, "null")),
                       width = unit(1, "npc"), height = unit(1, "npc") - unit(2, "lines"),
                       name = paste("node_mvar", tid, sep = ""))
    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = bg, col = 0))

    ## main title
    if (is.null(mainlab)) {
      mainlab <- if(id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(tid, nobs)
    }

    for(i in 1L:k) {
      tmp <- obj
      tmp$fitted[["(response)"]] <- y[,which[i]]
      if(varlab) {
        nm <- names(y)[which[i]]
        if(i == 1L) nm <- paste(mainlab, nm, sep = ": ")
      } else {
        nm <- if(i == 1L) mainlab else ""
      }
      ## choose the panel by the class of the current response column
      ## (class(...)[1L] also handles ordered factors, whose class has length 2)
      pfun <- switch(class(y[[which[i]]])[1L],
                     "Surv" = node_surv(tmp, id = id, mainlab = nm, ...),
                     "factor" = node_barplot(tmp, id = id, mainlab = nm,  ...),
                     "ordered" = node_barplot(tmp, id = id, mainlab = nm, ...),
                     terminal_panel_mvar(tmp, id = id, mainlab = nm, ...))
      ## select panel
      plot_vpi <- viewport(layout.pos.col = 2L, layout.pos.row = i)
      pushViewport(plot_vpi)

      ## call panel function
      pfun(node)

      if(pop) popViewport() else upViewport()
    }
    if(pop) popViewport() else upViewport()
  }

  return(rval)
}
class(node_mvar_violin) <- "grapcon_generator"

All in all, the result will look like this:

[Figure: resulting tree plot with violin panels in the terminal nodes]


Achim Zeileis

Reputation: 17203

There are two natural strategies for this:

  1. Write a node_violinplot() panel-generating function similar to node_boxplot().
  2. Use ggplot2 via the ggparty package and leverage the existing geom_violin().

For the first strategy, I would recommend copying the code of node_boxplot() (including setting its class!) and renaming it to, say, node_violinplot(). Most of its code is responsible for setting up the right viewports, axis ranges, etc., which can all be preserved. Then one would "only" replace the grid.lines() and grid.rect() calls that draw the box with calls that draw the violin. I'm not sure what would be the best way to compute the coordinates for the violin elements, though.
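
One possible way to get those coordinates, sketched under stated assumptions: mirror a kernel density estimate around the horizontal centre and draw it with grid.polygon(), using "npc" units horizontally and "native" (data) units vertically, then overlay a thin box from boxplot.stats(). The stand-alone sketch below uses only the grid package that ships with R; draw_violin() is a hypothetical name, not a partykit function.

```r
library("grid")

## draw_violin() is a hypothetical stand-alone sketch, not a partykit
## function: it draws a violin (mirrored density) plus a thin inner box.
draw_violin <- function(y, width = 0.8, fill = "lightgray") {
  yscale <- range(y) + c(-0.1, 0.1) * diff(range(y))
  kde <- stats::density(y, from = yscale[1], to = yscale[2])
  keep <- kde$x > min(y) & kde$x < max(y)
  dx <- kde$x[keep]
  half <- kde$y[keep] * width / 2 / max(kde$y[keep])  # half-widths in npc
  st <- boxplot.stats(y)$stats  # lower whisker, lower hinge, median, upper hinge, upper whisker

  pushViewport(viewport(width = 0.8, height = 0.8,
                        xscale = c(0, 1), yscale = yscale))
  grid.yaxis()
  grid.rect(gp = gpar(fill = "transparent"))
  ## violin outline: x in npc, y in native (data) units
  grid.polygon(x = unit(c(0.5 - half, rev(0.5 + half)), "npc"),
               y = unit(c(dx, rev(dx)), "native"),
               gp = gpar(fill = fill), name = "violin")
  ## box between the hinges, whiskers, and a white median point
  grid.rect(x = unit(0.5, "npc"), y = unit(st[2], "native"),
            width = unit(0.03, "npc"), height = unit(st[4] - st[2], "native"),
            just = c("center", "bottom"), gp = gpar(fill = "black"))
  grid.lines(unit(c(0.5, 0.5), "npc"), unit(st[1:2], "native"))
  grid.lines(unit(c(0.5, 0.5), "npc"), unit(st[4:5], "native"))
  grid.points(unit(0.5, "npc"), unit(st[3], "native"), pch = 19,
              size = unit(0.5, "char"), gp = gpar(col = "white"))
  popViewport()
  invisible(st)
}

grid.newpage()
draw_violin(cars$dist)
```

Inside a real panel function, the pushViewport()/popViewport() pair here would be replaced by the viewports that node_boxplot() already sets up, so only the drawing calls would need to be swapped in.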

For the second strategy, all building blocks are essentially available and just have to be customized to obtain the kind of violin plot that you want. For example:

[Figure: ggparty with geom_violin and geom_boxplot as geom_node_plot]

This plot can be replicated as follows:

## example tree
library("partykit")
ct <- ctree(dist ~ speed, data = cars)

## visualization with ggparty + geom_violin
library("ggparty")
ggparty(ct) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_violin(aes(x = "", y = dist)),
    geom_boxplot(aes(x = "", y = dist), coef = Inf, width = 0.1, fill = "lightgray"),
    xlab(""),
    theme_minimal()
  ))

