# dev/style_node.R

library(ctree)
library(party)
library(dplyr)
library(xlsx)

# Download the data (S1_File.xlsx) directly from the journal.
fileURL <- "http://s3-eu-west-1.amazonaws.com/files.figshare.com/2149369/S1_File.xlsx"
# Download the Excel file in binary mode ("wb"); otherwise it cannot be read
# (at least on Windows).
destfile <- tempfile(fileext = ".xlsx")
download.file(fileURL, destfile, mode = "wb")
# Read the first (and only) sheet, then remove the temporary copy.
OzRodents <- read.xlsx(destfile, sheetIndex = 1)
unlink(destfile)
# as_tibble() replaces tbl_df(), which is deprecated since dplyr 1.0.0.
OzRodents <- as_tibble(OzRodents)
OzRodents
# Rename the model columns, then recode them to match the labels used in the
# figure. Each recoded column is stored as a factor:
#   declinecategory: any value other than "none" -> 1, "none" -> 0
#   North.South:     1 -> "North", any other value -> "South"
#   Habitat:         habopenness > 2 -> "forest", otherwise "open"
#                    (cut-off based on the data and as described in the
#                    methods section)
OzRodents <- OzRodents %>%
  rename(North.South = NS,
         Body.mass = fmass,
         Habitat = habopenness) %>%
  mutate(
    declinecategory = as.factor(ifelse(declinecategory != "none", 1, 0)),
    North.South = as.factor(ifelse(North.South != 1, "South", "North")),
    Habitat = as.factor(ifelse(Habitat > 2, "forest", "open"))
  )


# Recursive partitioning: fit party's conditional inference tree for the
# binary decline outcome, with testtype = "Teststatistic" in the controls.
rodCT <- ctree(
  declinecategory ~ North.South + Body.mass + Habitat,
  data = OzRodents,
  controls = ctree_control(testtype = "Teststatistic")
)
# Show the fitted tree, then draw it with the default panels.
rodCT
plot(rodCT)

# Custom barplot panel for the terminal nodes of a party "BinaryTree" plot,
# modified from code provided by Achim Zeileis to the R-help mailing list
# (based on party's node_barplot).
#
# node_barplot2(ctreeobj, ...) returns the panel function(node) that draws a
# single node. The generator carries class "grapcon_generator", so
# plot.BinaryTree() may also be given node_barplot2 itself; it then calls it
# with the tree plus any tp_args.
#
# Arguments:
#   ctreeobj  fitted tree; only response(ctreeobj) and, for a non-factor
#             response, ctreeobj@tree are used.
#   col       border colour(s), recycled over the rectangles drawn.
#   fill      fill colours, recycled over the predictions; NULL means
#             gray.colors(). Fill is NOT reversed when `reverse` is TRUE.
#   beside    TRUE for side-by-side bars, FALSE for one stacked bar; NULL
#             means TRUE only for a factor response with 3 or more levels.
#   ymax      upper y limit; NULL means 1 (stacked) or 1.1 (beside) for a
#             factor response, else 1.1 * the largest prediction in the tree.
#   ylines    widths, in text lines, of the margins left and right of the
#             plot region; NULL means c(3, 4) if beside, else c(1.5, 2.5).
#   widths    bar width(s), recycled over the bars.
#   gap       space between bars, as a multiple of sum(widths); NULL means
#             0.1 (beside) / 0 (stacked) for a factor response, else 1.
#   reverse   draw the predictions in reverse order; NULL means !beside.
#   id        if TRUE, the node title includes "Node <id>".
node_barplot2 <- function(ctreeobj,
                          col = "black",
                          fill = c("red", "white"),
                          beside = NULL,
                          ymax = NULL,
                          ylines = NULL,
                          widths = 1,
                          gap = NULL,
                          reverse = NULL,
                          id = TRUE)
{
  # Largest prediction anywhere in the (sub)tree rooted at x.
  getMaxPred <- function(x) {
    mp <- max(x$prediction)
    mpl <- if (x$terminal) 0 else getMaxPred(x$left)
    mpr <- if (x$terminal) 0 else getMaxPred(x$right)
    max(c(mp, mpl, mpr))
  }

  y <- response(ctreeobj)[[1]]

  # inherits() instead of class(y) == "...": class() can have length > 1
  # (e.g. a matrix response), which makes `||` an error in R >= 4.3.
  if (is.factor(y) || inherits(y, "was_ordered")) {
    ylevels <- levels(y)
    if (is.null(beside)) beside <- length(ylevels) >= 3
    if (is.null(ymax)) ymax <- if (beside) 1.1 else 1
    if (is.null(gap)) gap <- if (beside) 0.1 else 0
  } else {
    if (is.null(beside)) beside <- FALSE
    if (is.null(ymax)) ymax <- getMaxPred(ctreeobj@tree) * 1.1
    ylevels <- seq_along(ctreeobj@tree$prediction)
    if (length(ylevels) < 2) ylevels <- ""
    if (is.null(gap)) gap <- 1
  }
  if (is.null(reverse)) reverse <- !beside
  if (is.null(fill)) fill <- gray.colors(length(ylevels))
  if (is.null(ylines)) ylines <- if (beside) c(3, 4) else c(1.5, 2.5)

  ### panel function for barplots in nodes
  rval <- function(node) {

    ## parameter setup
    pred <- node$prediction
    if (reverse) {
      pred <- rev(pred)
      ylevels <- rev(ylevels)
    }
    np <- length(pred)
    nc <- if (beside) np else 1

    fill <- rep(fill, length.out = np)
    widths <- rep(widths, length.out = nc)
    # One border colour per rectangle drawn. Stacked mode also draws np
    # rectangles; recycling to nc (= 1 there) left col[2], col[3], ... as NA,
    # i.e. no border on every segment except the first.
    col <- rep(col, length.out = np)
    ylines <- rep(ylines, length.out = 2)

    gap <- gap * sum(widths)
    yscale <- c(0, ymax)
    xscale <- c(0, sum(widths) + (nc + 1) * gap)

    # 2 x 3 layout: a 3.5-line title row above a plot row that takes the
    # remaining height; the plot column is flanked by ylines[1] / ylines[2]
    # lines of margin.
    top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                                            widths = unit(c(ylines[1], 1, ylines[2]), c("lines", "null", "lines")),
                                            heights = unit(c(3.5, 3.5), c("lines", "null"))),
                       width = unit(1, "npc"),
                       height = unit(1, "npc") - unit(2, "lines"),
                       name = paste0("node_barplot", node$nodeID))

    pushViewport(top_vp)
    grid.rect(gp = gpar(fill = "white", col = 0))

    ## main title
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)
    mainlab <- if (id) {
      paste0(paste("Node", node$nodeID, "\n", "(n = "), sum(node$weights), ")")
    } else {
      paste0("n = ", sum(node$weights))
    }
    grid.text(mainlab)
    popViewport()

    plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                     xscale = xscale, yscale = yscale,
                     name = paste0("node_barplot", node$nodeID, "plot"))

    pushViewport(plot)

    if (beside) {
      xcenter <- cumsum(widths + gap) - widths / 2
      for (i in seq_len(np)) {
        grid.rect(x = xcenter[i], y = 0, height = pred[i],
                  width = widths[i],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
      if (length(xcenter) > 1) grid.xaxis(at = xcenter, label = FALSE)
      grid.text(ylevels, x = xcenter, y = unit(-1, "lines"),
                default.units = "native",
                just = c("center", "bottom"),
                check.overlap = TRUE)
      grid.yaxis()
    } else {
      # ycenter[i] is the bottom of segment i: the summed heights below it.
      ycenter <- cumsum(pred) - pred

      for (i in seq_len(np)) {
        # Height is capped so the stack never extends above ymax.
        grid.rect(x = xscale[2] / 2, y = ycenter[i], height = min(pred[i], ymax - ycenter[i]),
                  width = widths[1],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }

      # Single tick on the right-hand axis at 1 - pred[np], i.e. the bottom
      # of the top segment when the predictions sum to 1.
      grid.yaxis(at = round(1 - pred[np], digits = 2), main = FALSE)
    }

    grid.rect(gp = gpar(fill = "transparent"))
    upViewport(2)
  }

  rval
}
class(node_barplot2) <- "grapcon_generator"

# Inner-node panel for plot.BinaryTree (after a function posted by user
# "agstudy"): draws a white circle with a gray outline and writes the split
# variable's name and the node's total weight, "(n = ...)", inside it.
innerWeights <- function(node) {
  n_obs <- sum(node$weights)
  label <- paste0(paste(node$psplit$variableName, "\n(n = "), n_obs, ")")
  grid.circle(r = 0.36, gp = gpar(fill = "white", col = "gray"))
  grid.text(label, gp = gpar(col = "black"))
}

# Final figure: custom inner nodes (split variable + n) and custom terminal
# barplots. `ylines` sets the margin, in text lines, to the left/right of each
# terminal barplot. Pass it to node_barplot2() directly, not via tp_args:
# plot() only forwards tp_args when terminal_panel is the grapcon_generator
# itself, and node_barplot2(rodCT, ...) has already been called here, so
# tp_args would be silently ignored.
plot(rodCT,
     inner_panel = innerWeights,
     terminal_panel = node_barplot2(
       rodCT,
       col = "blue",
       fill = c("darkred", "gray"),
       reverse = TRUE,
       ylines = c(2, 4)
     ))
# jbkunst/treekir documentation built on May 3, 2019, 5:18 p.m.