user3357059
user3357059

Reputation: 1192

Remove text.y from box plot when generating plot with the ggparty package

I am trying to remove the text on the Y axis of some of the bar plots. I have updated the 'scales' and 'shared_axis_labels' options of the 'geom_node_plot' function, to no avail. Below is some code to illustrate the issue and a plot of the labels that I want to remove.

library(ggparty)
library(tidyverse)
library(partykit)

# Fit a conditional inference tree (partykit::ctree) that predicts Species
# from all remaining columns of the built-in iris data.
ct <- ctree(Species ~ ., data = iris)

# Proportion of each count within its own panel: every count is divided by
# the sum of all counts that share the same `panel` value.
panel_prop <- function(count, panel) {
  panel_totals <- tapply(count, panel, sum)
  count / panel_totals[as.character(panel)]
}

# Draw the fitted tree: edges with their labels, one bar chart per node-plot
# panel, and multi-line text labels on the inner nodes.
ggparty(ct) +
  geom_edge() +
  geom_edge_label(colour = "gray9", size = 3) +
  # `scales` and `shared_axis_labels` are the two options the question reports
  # changing without any effect on the y-axis text.
  geom_node_plot(scales = "fixed",
                 shared_axis_labels = FALSE,
                 gglist = list( aes(
                   y = Species,
                   # bar length: count as a share of its panel's total
                   x = after_stat(panel_prop(count, PANEL))
                   ,fill = Species
                   # same share, formatted as a whole-number percentage label
                   ,label = after_stat(scales::percent(panel_prop(count, PANEL), accuracy = 1))
                 ),
                 geom_bar(),
                 # stat = "count" so the label's after_stat(count) is available
                 geom_text(stat = "count", hjust = 0, size = 3),
                 # clip = "off" lets text run past the panel edge
                 coord_cartesian(clip = "off"),
                 # extra expansion on the upper end leaves room for the labels
                 scale_x_continuous(labels = scales::percent_format(accuracy = 1),
                                    expand = expansion(mult = c(.05, .25))),
                 theme(axis.text.x = element_text(size = 5)),
                 # blank out both axis titles
                 xlab(""),
                 ylab(""))) +
  # Inner-node labels made of three lines: "Node <id>", the split variable,
  # and an empty line; line_gpar gives the graphical parameters per line.
  geom_node_label(aes(),
                  line_list = list(aes(label = paste("Node", id)),
                                   aes(label = splitvar),
                                   aes(label = "")
                  ),  line_gpar = list(list(size = 8,
                                            col = "black"#, fontface = "bold"
                  ),  list(size = 6), list(size = 8)
                  ),  ids = "inner")

enter image description here

Upvotes: 2

Views: 85

Answers (1)

Kat
Kat

Reputation: 18754

I have a solution that works, but I think it's more than a bit overkill. I was sure there was a relatively easy way to make that happen...ya, no, I didn't find one...

I tried to make this solution dynamic, but there are an endless amount of things you can do with ggplot, so I'm sure it won't work in many situations that differ from your specific question.

There is an assumption that there is only one call to geom_node_plot (I'm not sure whether ggparty lets you have more than one...). Another assumption is that if there are labels, they are percentages. While I look specifically for geom_text, I didn't add handling for geom_label (the whole, you can have one, but not the other....and ya, another rabbit hole I was lost in for a while... I digress.)

The function fixer takes in the graph you made, picks it apart, and remakes it. The plot has to be in the environment first, so chaining the fixer function will give you an error (e.g., you can't use ggplot... %>% fixer()).

Unfortunately it's not a ggplot object when it's finished, it's a gtree... take a look and let me know if you have any questions.

Your originally coded plot is unchanged, except in that it is assigned to gg.

updated plot

library(ggparty)  # install.packages("ggparty")
library(tidyverse)
library(partykit)

# fixer function for modifying the plot
#
# Takes a finished ggparty plot, separates the tree layers from the single
# geom_node_plot() layer, redraws the node bar charts as one faceted ggplot
# (one facet per row of deepest-level node data), and stacks the two drawings
# with grid viewports.
#
# Arguments:
#   gg  a ggparty plot containing exactly one layer whose constructor name
#       contains "plot" (i.e. geom_node_plot()), with its `gglist` written
#       inline as list(...).
#
# Returns: the grob captured by grid.grab() after both parts were drawn.
#
# Side effects: attaches scales and grid, prints the tree to the current
# device and draws the faceted bars on top of it.
#
# Limitations kept from the original design:
#   * labels added for geom_text() are always percentages of `value`;
#   * the facet-strip / legend / margin theme tweaks are only applied when
#     `gglist` contains a theme(...) call with named arguments;
#   * only a bare `geom_bar()` (no arguments) is replaced by geom_col().
fixer <- function(gg) {
  # library() instead of require(): a missing package stops here instead of
  # failing later with an unrelated error. Both packages stay attached, so
  # code that calls grid functions unqualified after fixer() keeps working.
  library(scales)
  library(grid)

  # Name of the function invoked by a call that was split with as.list().
  # For a namespaced call (pkg::fun) as.character() yields c("::", pkg, fun),
  # so the last element is always the bare function name.
  called_fn <- function(parts) {
    fn <- as.character(parts[[1]])
    fn[length(fn)]
  }

  # classify every layer: "p" = node-plot layer, "np" = everything else
  constructors <- map_chr(gg$layers, \(lyr) {
    is_plot <- str_detect(called_fn(as.list(lyr$constructor)), "plot")
    if (isTRUE(is_plot)) "p" else "np"
  })

  pL <- which(constructors == "p")
  nL <- which(constructors == "np")
  if (length(pL) != 1L) {
    stop("fixer() needs exactly one geom_node_plot() layer; found ",
         length(pL), " layer(s) whose constructor name contains 'plot'.",
         call. = FALSE)
  }

  # two copies of the original plot that are stripped down separately
  pplt <- nplt <- gg
  nplt$layers[pL] <- NULL  # tree only: node-plot layer removed
  pplt$layers[nL] <- NULL  # node-plot layer only: tree layers removed

  # keep the rows at the deepest level and only the node data columns;
  # drop = FALSE keeps a data frame even when a single column matches
  dta <- pplt$data %>% filter(level == max(level))
  dta <- dta[, str_detect(names(dta), "nodedata"), drop = FALSE]

  # per kept row: share of each level in the first node data column
  # (assumes that first column holds the factor being plotted -- TODO confirm)
  framer <- map(seq_len(nrow(dta)), \(k) {
    dtb <- dta[k, 1] %>% unlist()
    pct <- as.data.frame(t(summary(dtb))) / length(dtb)
    pct$grp <- paste0("grp", k)
    pct
  }) %>% list_rbind()

  framer <- pivot_longer(framer, -grp) # long format: grp, name, value

  # the unevaluated gglist = list(...) expression of the node-plot layer
  plotLayer <- as.list(gg$layers[[pL]]$constructor)
  pltList <- as.list(plotLayer[["gglist"]])

  pltO <- ggplot(framer, aes(value, name, group = grp, fill = name))

  # rebuild the geom_node_plot layer piece by piece; element 1 is the `list`
  # symbol itself, so start at 2 (seq_along()[-1] is empty for short lists)
  for (k in seq_along(pltList)[-1]) {
    piece <- pltList[[k]]
    plst <- as.list(piece)

    if (length(plst) == 1) {
      # a call without arguments
      if (called_fn(plst) == "geom_bar") {
        # values are precomputed shares, so bars are drawn with geom_col()
        pltO <- pltO + geom_col()
      } else {
        pltO <- pltO + eval(piece)      # other calls can be left as written
      }
    } else if (is.null(names(plst))) {
      pltO <- pltO + eval(piece)        # only unnamed arguments: as written
    } else {
      thats <- called_fn(plst)
      if (thats == "geom_text") {
        # replace data/mapping/stat: labels come from the precomputed shares,
        # and zero shares get no label
        iL <- list(data = framer[framer$value != 0, ],
                   mapping = aes(value, name, group = grp,
                                 label = label_percent(accuracy = 1)(value)))
        plst <- plst[!names(plst) %in% c("stat", "aes", "mapping", "data")]
        # plst[-1] drops the function name; do.call() evaluates the remaining
        # still-quoted arguments when it runs the call
        pltO <- pltO + do.call(geom_text, append(iL, plst[-1]))
      } else if (thats == "theme") {
        user_args <- plst[-1]
        iL <- list(panel.spacing.x = unit(1, "lines"),  # space for x-axis text
                   strip.background = element_blank(),  # no facet strip bg
                   strip.text = element_blank(),        # no facet strip text
                   legend.position = "none",            # no legend
                   plot.margin = margin(0, 15, 15, 15)) # space except on top
        # settings given in the original theme() win over the injected ones;
        # passing the same argument twice would make do.call() fail
        iL <- iL[!names(iL) %in% names(user_args)]
        pltO <- pltO + do.call(theme, append(iL, user_args))
      } else if (thats != "aes") {                      # aes() is rebuilt above
        pltO <- pltO + eval(piece)
      }
    }
  }
  pltO <- pltO + facet_wrap(~grp, nrow = 1)

  # using viewports to reassemble: print the tree, then draw the faceted bars
  # into rows 4-5 of a 5-row layout laid over the same page
  print(nplt + theme(plot.margin = margin(b = 0, l = 60, r = 0)))
  pushViewport(current.viewport())
  pushViewport(viewport(layout = grid.layout(5, 1)))
  pushViewport(viewport(layout.pos.row = 4:5, layout.pos.col = 1))
  grid.draw(ggplotGrob(pltO))
  popViewport(3)   # undo the three pushViewport() calls above
  grabby <- grid.grab(wrap.grobs = TRUE)
  grabby
}

#---------- with your code as in your question ---------
# Conditional inference tree (partykit::ctree): Species modelled on all
# other columns of iris.
ct <- ctree(Species ~ ., data = iris)

# Within-panel share: each element of `count` divided by the total of the
# counts that have the same `panel` value.
panel_prop <- function(count, panel) {
  totals_by_panel <- tapply(count, panel, sum)
  denominators <- totals_by_panel[as.character(panel)]
  count / denominators
}
# original plot
# Same plot as in the question, stored in `gg` so it can be handed to fixer().
# fixer() reads the gglist expression back from the layer's constructor call,
# so the list(...) below has to stay written inline in this call.
gg <- ggparty(ct) +
  geom_edge() +
  geom_edge_label(colour = "gray9", size = 3) +
  geom_node_plot(
    scales = "fixed", shared_axis_labels = FALSE,
    gglist = list(aes(
      # bar length: count as a share of its panel's total
      y = Species, x = after_stat(panel_prop(count, PANEL)),
      fill = Species,
      # same share, formatted as a whole-number percentage label
      label = after_stat(scales::percent(panel_prop(count, PANEL),
                                         accuracy = 1))),
      geom_bar(), 
      geom_text(stat = "count", hjust = 0, size = 3),
      # clip = "off" lets text run past the panel edge
      coord_cartesian(clip = "off"),
      # extra expansion on the upper end leaves room for the labels
      scale_x_continuous(labels = scales::percent_format(accuracy = 1),
                         expand = expansion(mult = c(.05, .25))),
      theme(axis.text.x = element_text(size = 5)),
      xlab(""), ylab(""))) +
  # inner-node labels: "Node <id>", the split variable, and an empty line,
  # with per-line graphical parameters in line_gpar
  geom_node_label(
    aes(),
    line_list = list(aes(label = paste("Node", id)),
                     aes(label = splitvar),
                     aes(label = "")),  
    line_gpar = list(list(size = 8, col = "black"),  
                     list(size = 6), list(size = 8)),  
    ids = "inner") 

#----------- apply modifications with fixer() -------------
# fixer() draws the rebuilt plot itself (it shows in the plot pane without
# calling gg2) and returns the captured grid drawing, stored here as gg2
gg2 <- fixer(gg)

# to reprint to the plot pane: start a fresh page, then draw the captured grob
# (grid functions are available unqualified because fixer() attached grid)
grid.newpage()
grid.draw(gg2)

# alternatively -- this combines the two previous calls
cowplot::ggdraw(gg2)

Upvotes: 1

Related Questions