R/zPlot.R

Defines functions pdpZenplot

Documented in pdpZenplot

#' Create a zenplot displaying partial dependence values.
#'
#' @description Constructs a zigzag expanded navigation plot (zenplot) displaying
#' partial dependence (or accumulated local effects) values for consecutive pairs
#' of predictors along a zenpath.
#'
#' @param task Task created from the mlr3 package, either regression or classification.
#' @param model A machine learning model created from mlr3 task and learner.
#' @param zpath A zenpath created from calcZpath: either a character vector of
#' predictor names or a list of character vectors (one per group). If \code{NULL}
#' (default), all predictors are used in their original order. See
#' \code{\link[zenplots]{zenpath}} from the \code{\link[zenplots]{zenplots}}
#' package for more details.
#' @param method Method used to compute the feature effects: "pdp" (default) or
#' "ale". Passed to \code{\link[iml]{FeatureEffect}}.
#' @param noCols Number of columns of 2d plots (>= 1) or one of "letter", "square",
#' "A4", "golden" or "legal", in which case a similar layout is constructed.
#' Passed to \code{zenplot} as \code{n2dcols}; see ?zenplot. An \code{n2dcols}
#' argument supplied through \code{...} takes precedence.
#' @param zenMethod String indicating the layout of the zigzag plot, passed to
#' \code{zenplot} as \code{method}. The available methods are:
#' "tidy": more tidied-up double.zigzag (slightly more compact placement of plots towards the end).
#' "double.zigzag": zigzag plot in the form of a flipped “S”. Along this path, the plots are placed in the form of an “S” which is rotated counterclockwise by 90 degrees.
#' "single.zigzag": zigzag plot in the form of a flipped “S”.
#' "rectangular": plots that fill the page from left to right and top to bottom. This is useful (and most compact) for plots that do not share an axis.
#' @param pal A vector of colors to show predictions, for use with scale_fill_gradientn.
#' @param fitlims Specifies the fit range for the color map. Either a numeric
#'  vector of length 2, or "pdp" (default; \code{NULL} is treated the same way),
#'  in which case the limits are calculated from the range of the computed effects.
#'  Predictions outside fitlims are squished on the color scale.
#' @param gridSize Grid size for the pdp/ale plots; defaults to 10.
#' @param nmax Maximum number of data rows to consider for calculating the pdp.
#'  If the task has more rows, a random sample of \code{nmax} rows is used.
#' @param class For a classification model, show the probability of this class. Defaults to 1.
#' @param ... Additional arguments passed to \code{\link[zenplots]{zenplot}}.
#'
#' @return A zenplot of partial dependence values (the value returned by
#' \code{\link[zenplots]{zenplot}}).
#'
#' @importFrom zenplots "zenplot"
#' @importFrom zenplots "indexData"
#' @importFrom zenplots "groupData"
#' @importFrom iml "FeatureEffect"
#' @importFrom iml "Predictor"
#'
#' @examples
#' # Load in the data:
#' aq <- na.omit(airquality)*1.0
#'
#' # Run an mlr3 ranger model:
#' library(mlr3)
#' library(mlr3learners)
#' library(ranger)
#' ozonet  <- TaskRegr$new(id = "airQ", backend = aq, target = "Ozone")
#' ozonel  <- lrn("regr.ranger", importance = "permutation")
#' ozonef  <- ozonel$train(ozonet)
#'
#' # Create matrix
#' viv <- vividMatrix(ozonet, ozonef)
#'
#' # Calculate Zpath:
#' zpath <- calcZpath(viv, 0.8)
#' zpath
#'
#' # Create graph:
#' pdpZenplot(ozonet, ozonef, zpath = zpath)
#'
#' @export
pdpZenplot <- function(task, model, zpath = NULL, method = "pdp",
                       noCols = c("letter", "square", "A4", "golden", "legal"),
                       zenMethod = c("tidy", "double.zigzag", "single.zigzag", "rectangular"),
                       pal = rev(RColorBrewer::brewer.pal(11, "RdYlBu")),
                       fitlims = "pdp", gridSize = 10, nmax = 500, class = 1, ...) {

  # Validate enumerated options before doing any expensive work.
  method <- match.arg(method, c("pdp", "ale"))
  zenMethod <- match.arg(zenMethod)
  if (is.character(noCols)) noCols <- match.arg(noCols)

  # `fitlims` may be a numeric vector of length 2; compare with identical()
  # so that case never reaches `if` as a length-2 condition (an error in R >= 4.2).
  autoLimits <- is.null(fitlims) || identical(fitlims, "pdp")
  if (!autoLimits && !(is.numeric(fitlims) && length(fitlims) == 2)) {
    stop("'fitlims' must be \"pdp\", NULL or a numeric vector of length 2.",
         call. = FALSE)
  }

  prob <- model$task_type == "classif"
  data <- as.data.frame(task$data())
  target <- task$target_names

  # Subsample rows to bound the cost of the effect computations.
  if (nmax < nrow(data)) {
    data <- data[sample(nrow(data), nmax), , drop = FALSE]
  }

  # Wrap the model in an iml Predictor; for classification, report the
  # probability of the requested class.
  pred.data <- if (prob) {
    Predictor$new(model, data = data, class = class, y = target)
  } else {
    Predictor$new(model, data = data, y = target)
  }

  # Predictor columns (target excluded) used for the pairs.
  xdata <- pred.data$data$get.x()

  # Build the (n - 1) x 2 matrix of consecutive index pairs along a path.
  # Every second pair is reversed (presumably to match zenplot's alternating
  # panel orientation -- confirm).
  zigzagPairs <- function(idx) {
    if (length(idx) < 2) {
      stop("'zpath' must contain at least two predictors.", call. = FALSE)
    }
    pairs <- vapply(seq_len(length(idx) - 1), function(i) {
      z <- idx[i:(i + 1)]
      if (i %% 2 == 0) rev(z) else z
    }, integer(2))
    t(pairs)
  }

  if (is.null(zpath)) {
    zpath <- seq_len(ncol(xdata))
    zdata <- xdata
    zpairs <- zigzagPairs(zpath)
  } else if (is.character(zpath)) {
    zpath <- match(zpath, names(xdata))
    if (anyNA(zpath)) stop("'zpath' should contain predictor names.")
    zdata <- indexData(xdata, zpath)
    zpairs <- zigzagPairs(zpath)
  } else if (is.list(zpath)) {
    zpath0 <- match(unlist(zpath), names(xdata))
    if (anyNA(zpath0)) stop("'zpath' should contain predictor names.")
    zpath <- lapply(zpath, function(z) match(z, names(xdata)))
    zpairs <- zigzagPairs(zpath0)
    # A pair that straddles two groups is not computed: mark it NA so it is
    # drawn as an empty panel below.
    groupEnds <- cumsum(lengths(zpath))
    zpairs[groupEnds[-length(groupEnds)], ] <- NA
    zdata <- groupData(xdata, indices = zpath)
  } else {
    stop("'zpath' must be NULL, a character vector or a list of character vectors.",
         call. = FALSE)
  }

  # Create progress bar
  pb1 <- progress_bar$new(
    format = "  Calculating partial dependence...[:bar]:percent. Est::eta ",
    total = nrow(zpairs),
    clear = FALSE)

  # One 2d effect per pair; NA rows (group boundaries) store NULL.
  pdplist <- vector("list", nrow(zpairs))
  for (i in seq_len(nrow(zpairs))) {
    ind <- zpairs[i, ]
    p <- if (is.na(ind[1])) {
      NULL
    } else {
      FeatureEffect$new(pred.data, ind, method = method, grid.size = gridSize)
    }
    pdplist[[i]] <- list(index = ind, pdp = p)
    pb1$tick()
  }

  # Extract the effect values from a FeatureEffect's results, locating the
  # value column by name (its name varies across iml versions and methods)
  # and falling back to the third column.
  effectValues <- function(fe) {
    res <- fe$results
    valueCol <- intersect(c(".value", ".y.hat", ".ale"), names(res))
    if (length(valueCol) > 0) res[[valueCol[1]]] else res[[3]]
  }

  # Shared colour limits for all panels.
  if (autoLimits) {
    effects <- Filter(Negate(is.null), lapply(pdplist, function(x) x$pdp))
    r <- range(unlist(lapply(effects, effectValues)), na.rm = TRUE)
    limits <- range(labeling::rpretty(r[1], r[2]))
  } else {
    limits <- fitlims
  }

  # Zenplot 2d panel function. A counter maps each call to the next entry of
  # pdplist; this assumes zenplot calls plot2d once per 2d panel, in path order.
  z2index <- 0
  ggplot2d <- function(zargs) {
    z2index <<- z2index + 1
    pdp <- pdplist[[z2index]]$pdp

    if (is.null(pdp)) {
      p <- ggplot() + theme(panel.background = element_blank())
    } else {
      p <- plot(pdp, rug = TRUE) +
        scale_fill_gradientn(name = "\u0177", colors = pal, limits = limits,
                             oob = scales::squish) +
        guides(fill = "none", color = "none") +
        theme_bw() +
        theme(axis.line = element_blank(),
              axis.ticks = element_blank(),
              axis.text.x = element_blank(),
              axis.text.y = element_blank(),
              axis.title.x = element_blank(),
              axis.title.y = element_blank(),
              panel.border = element_rect(colour = "gray", fill = NA, size = 1.5))
    }

    ggplot_gtable(ggplot_build(p))
  }

  # Honour noCols unless the caller already passes n2dcols through `...`,
  # which would otherwise be a duplicated argument.
  userCols <- "n2dcols" %in% names(list(...))
  suppressMessages({
    if (userCols) {
      zenplot(zdata, pkg = "grid", labs = list(group = NULL),
              plot2d = function(zargs) ggplot2d(zargs),
              method = zenMethod, ...)
    } else {
      zenplot(zdata, pkg = "grid", labs = list(group = NULL),
              plot2d = function(zargs) ggplot2d(zargs),
              n2dcols = noCols, method = zenMethod, ...)
    }
  })
}



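# -------------------------------------------------------------------------
# Commented-out alternative version of pdpZenplot (not executed). It lacks the
# nmax subsampling and squished colour limits of the function above, and
# contains an incomplete attempt to add a colour-bar legend via
# cowplot::get_legend().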


# pdpZenplot <- function(task, model, zpath=NULL, method = "pdp",
#                        noCols = c("letter", "square", "A4", "golden", "legal"),
#                        zenMethod = c("tidy", "double.zigzag", "single.zigzag", "rectangular"),
#                        pal=rev(RColorBrewer::brewer.pal(11,"RdYlBu")),
#                        fitlims = NULL, gridsize = 10, class = 1,...){
#
#   prob <- model$task_type == "classif"
#   data <-  task$data()
#   data <- as.data.frame(data)
#   target <- task$target_names
#
#
#   # make iml model
#   if (prob){
#     pred.data <- Predictor$new(model, data = data, class = class, y = target)
#   }
#   else{
#     pred.data <- Predictor$new(model, data = data, y = target)
#   }
#
#   # Get data for pairs of variables
#   xdata <- pred.data$data$get.x()
#
#   if (is.null(zpath)) {
#     zpath <- 1:ncol(xdata)
#     zdata <- xdata
#     zpairs <- t(sapply(1:(length(zpath)-1), function(i){
#       z <- zpath[i:(i+1)]
#       if (i %% 2 == 0) rev(z) else z
#     }))
#   }
#   else if (is.character(zpath)){
#     zpath <- match(zpath, names(xdata))
#     if (any(is.na(zpath))) stop("'zpath' should contain predictor names.")
#     zdata <- indexData(xdata, zpath)
#     zpairs <- t(sapply(1:(length(zpath)-1), function(i){
#       z <- zpath[i:(i+1)]
#       if (i %% 2 == 0) rev(z) else z
#     }))
#   }
#   else if (is.list(zpath)){
#     zpath0 <- unlist(zpath)
#     zpath0 <- match(zpath0, names(xdata))
#     if (any(is.na(zpath0))) stop("'zpath' should contain predictor names.")
#     zpath <- lapply(zpath, function(z) match(z, names(xdata)))
#     zpairs <- t(sapply(1:(length(zpath0)-1), function(i){
#       z <- zpath0[i:(i+1)]
#       if (i %% 2 == 0) rev(z) else z
#     }))
#     fixind <- cumsum(sapply(zpath, length))
#     fixind <- fixind[-length(fixind)]
#     for (i in fixind) zpairs[i,]<- NA
#     zdata <- groupData(xdata, indices = zpath)
#   }
#
#
#   # Create progress bar
#   pb1 <- progress_bar$new(
#     format = "  Calculating partial dependence...[:bar]:percent. Est::eta ",
#     total = nrow(zpairs),
#     clear = FALSE)
#
#   # loop through vars and create a list of pdps for each pair
#   pdplist <- vector("list", nrow(zpairs))
#
#   for (i in 1:nrow(zpairs)){
#     ind <- zpairs[i,]
#
#     if (!is.na(ind[1]))
#       p <-FeatureEffect$new(pred.data, ind, method = "pdp", grid.size=gridsize)
#     else p <- NULL
#     pdplist[[i]] <- list(index=ind, pdp=p)
#     pb1$tick()
#   }
#
#
#
#   # get predictions
#   Pred <- pred.data$predict(data)
#   colnames(Pred) <- "prd"
#   Pred <- Pred$prd
#
#
#   # Set limits for pairs
#   if (is.null(fitlims)){
#     pdplist0 <- lapply(pdplist, function(x) x$pdp)
#     pdplist0 <-pdplist0[!sapply(pdplist0, is.null)]
#     r <- sapply(pdplist0, function(x) range(x$results[,3]))
#     r <- range(c(r))
#     #r <- range(c(r,Pred))
#     limits <- range(labeling::rpretty(r[1],r[2]))
#   } else
#     limits <- fitlims
#
#
#   # Zenplot graphing function
#   z2index <- 0
#   ggplot2d <- function(zargs) {
#
#     z2index <<- z2index+1
#
#     pdp <- pdplist[[z2index]]$pdp
#     if (!is.null(pdp)) {
#       p <- plot(pdp, rug = TRUE) +
#         scale_fill_gradientn(name = "\u0177",colors = pal, limits = limits)+
#         guides(fill = FALSE, color = FALSE) +
#         theme_bw() +
#         theme(axis.line = element_blank(),
#               axis.ticks = element_blank(),
#               axis.text.x = element_blank(),
#               axis.text.y = element_blank(),
#               axis.title.x = element_blank(),
#               axis.title.y = element_blank(),
#               panel.border = element_rect(colour = "gray", fill=NA, size = 1.5))
#
#       pp <- plot(pdp, rug = TRUE) +
#         scale_fill_gradientn(name = "\u0177",colors = pal, limits = limits)+
#         guides(fill = "colourbar", color = FALSE) +
#         theme_bw() +
#         theme(axis.line = element_blank(),
#               axis.ticks = element_blank(),
#               axis.text.x = element_blank(),
#               axis.text.y = element_blank(),
#               axis.title.x = element_blank(),
#               axis.title.y = element_blank(),
#               panel.border = element_rect(colour = "gray", fill=NA, size = 1.5))
#
#       # Grab the legends using cowplot::get_legend()
#       p1_legend <- get_legend(p)
#
#
#       #endplot <- gridExtra::grid.arrange(p, p1_legend)
#
#       #endplot <- plot_grid(p, p1,
#       #                     align = "h",
#       #                     scale = c(1, 0.8),
#       #                     labels = c('A', 'B'),
#       #                     rel_widths = c(0.9, 0.1))
#
#     }
#     else p <- ggplot() + theme(panel.background = element_blank())
#
#     ggplot_gtable(ggplot_build(p))
#   }
#
#
#
#   suppressMessages({
#     zen.p <- zenplot(zdata, pkg="grid", labs=list(group=NULL),
#                      plot2d = function(zargs) ggplot2d(zargs), ...)
#     zenp.len <- plot_grid(zen.p, p1_legend)
#     print(zen.p)
#   })
#
# }
#
AlanInglis/vividOld documentation built on March 4, 2021, 12:44 a.m.