R/partials.R

Defines functions partial

Documented in partial

#' @title Better, nicer, friendlier partial dependence plots
#'
#' @description
#'
#' Partial dependence plots show the response curve of an individual variable in a sum-of-trees model. The main line is the median, across posterior draws, of the partial dependence curves; each of those curves is generated by evaluating the BART model prediction at each specified x value for *every combination of the other x values in the data*. This is computationally very expensive, and gets slower the more smoothing you add, the more variables you ask for, and the more posterior draws you keep (ndpost; defaults to 1000) in the bart() function.
#'
#' @param model A dbarts model object of class \code{bart} or \code{rbart} (the examples fit it with \code{keeptrees = TRUE}).
#' @param x.vars A character vector naming the variables for which you want to run the partials. Defaults to \code{NULL}, which does all of them.
#' @param equal If \code{TRUE} (the default), space the x levels equally between each variable's minimum and maximum instead of using quantiles, which is how dbarts does this normally (quantile levels follow the distribution of the data, which makes weird patterns that don't look very smooth). With \code{equal = TRUE}, variables with only two observed values are evaluated at exactly those two values.
#' @param smooth A multiplier for how finely the x levels are sampled: \code{10 * smooth} equal intervals when \code{equal = TRUE}, or quantile steps of \code{0.1 / smooth} otherwise. High values, like 10 or over, are much slower and don't add much; values above 10 trigger a warning.
#' @param ci Plot a credible interval in blue: a ribbon for continuous variables, filled boxplots for two-valued (binary) ones. Its width is set by \code{ciwidth}.
#' @param ciwidth Width of the plotted credible interval, as a proportion. Defaults to 0.95.
#' @param trace Draw the partial dependence curve of each individual posterior draw as a faint line behind the median.
#' @param transform Apply \code{pnorm} to the predictions, converting dbarts' probit-scale output for a binary response into 0 to 1 probabilities. I wouldn't turn this off for a binary response unless you're really interested in a deep dive on the model; turn it off for a continuous response.
#' @param panels For multiple variables, combine the plots into one multipanel figure (via patchwork).
#'
#' @return A list of ggplot objects, one per variable, when \code{panels = FALSE}; a single patchwork object combining them when \code{panels = TRUE}.
#'
#' @examples
#' set.seed(99)
#' f <- function(x) { 0.5 * x[,1] + 2 * x[,2] * x[,3] - 5 * x[,4] }
#' sigma <- 0.2
#' n <- 100
#' x <- matrix(2 * runif(n * 3) - 1, ncol = 3)
#' x <- data.frame(x)
#' x[,4] <- rbinom(n, 1, 0.3)
#' colnames(x) <- c('rob', 'hugh', 'ed', 'phil')
#' Ey <- f(x)
#' y  <- rnorm(n, Ey, sigma)
#' df <- data.frame(y, x)
#'
#' bartFit <- bart(y ~ rob + hugh + ed + phil, df,
#'                keepevery = 10, ntree = 100, keeptrees = TRUE)
#'
#' # y is continuous here, so skip the probit-to-probability transform
#' partial(bartFit, x.vars='hugh', trace=TRUE, ci=TRUE, transform=FALSE)
#' partial(bartFit, x.vars='hugh', equal=TRUE, trace=TRUE, ci=TRUE, transform=FALSE)
#' partial(bartFit, x.vars='hugh', equal=TRUE, smooth=10, trace=TRUE, ci=TRUE, transform=FALSE)
#'
#' partial(bartFit, x.vars='rob', equal=TRUE, smooth=10, trace=FALSE, ci=TRUE, transform=FALSE)
#' partial(bartFit, x.vars='ed', equal=TRUE, smooth=10, trace=TRUE, ci=FALSE, transform=FALSE)
#' partial(bartFit, equal=TRUE, smooth=10, trace=FALSE, ci=TRUE, panels=TRUE, transform=FALSE)
#'
#' @export

partial <- function(model, x.vars = NULL, equal = TRUE, smooth = 1,
                    ci = TRUE, ciwidth = 0.95, trace = TRUE,
                    transform = TRUE, panels = FALSE) {

  # Argument checks ----

  if (smooth > 10) {
    warning("You have chosen way, way too much smoothing... poorly")
  }

  if (!is.null(x.vars) && length(x.vars) == 1 && panels) {
    stop("Hey bud, you can't do several panels on only one variable!")
  }

  # inherits() instead of class() == "...": class() may have several
  # elements, which makes `if` error in R >= 4.2. Unsupported inputs used to
  # fail later with an obscure "object 'fitobj' not found".
  if (inherits(model, "rbart")) {
    fitobj <- model$fit[[1]]
  } else if (inherits(model, "bart")) {
    fitobj <- model$fit
  } else {
    stop("`model` must be a dbarts model of class 'bart' or 'rbart'.",
         call. = FALSE)
  }

  # Levels at which to evaluate the partial dependence ----

  # Training predictors for the requested variables. drop = FALSE keeps a
  # single requested variable as a one-column matrix, so there is only one
  # code path below.
  raw <- fitobj$data@x
  if (!is.null(x.vars)) {
    raw <- raw[, x.vars, drop = FALSE]
  }

  if (equal) {
    lev <- lapply(seq_len(ncol(raw)), function(j) {
      vals <- raw[, j]
      observed <- sort(unique(vals))
      # Two-valued (and constant) variables are evaluated at their observed
      # values; sorting keeps them aligned with the boxplot labels below.
      if (length(observed) <= 2) {
        return(observed)
      }
      # length.out (not `by`) so the grid always ends exactly at the maximum,
      # even for non-integer `smooth`.
      rng <- range(vals)
      seq(rng[1], rng[2], length.out = ceiling(10 * smooth) + 1)
    })
    pd <- pdbart(model, xind = x.vars, levs = lev, pl = FALSE)
  } else {
    # Deciles refined by `smooth`, plus the quantiles bounding the interval.
    levq <- sort(c(0.5 - ciwidth / 2, seq(0.1, 0.9, 0.1 / smooth),
                   0.5 + ciwidth / 2))
    pd <- pdbart(model, xind = x.vars, levquants = levq, pl = FALSE)
  }

  # Plotting ----

  # pnorm maps the probit-scale predictions onto probabilities.
  to_response <- if (transform) pnorm else identity

  # Shared look: roomier margins for standalone plots, tighter for panels.
  margin_cm <- if (panels) 0.15 else 0.5
  look <- list(
    theme_light(base_size = 20),
    theme(plot.title = element_text(hjust = 0.5),
          axis.title.y = element_text(vjust = 1.7),
          plot.margin = unit(rep(margin_cm, 4), "cm"))
  )

  plots <- vector("list", length(pd$fd))

  for (i in seq_along(pd$fd)) {
    # Rows are posterior draws; columns are the x levels in pd$levs[[i]].
    draws <- as.matrix(pd$fd[[i]])
    levels_i <- unname(pd$levs[[i]])
    var_name <- pd$xlbs[[i]]

    if (length(unique(fitobj$data@x[, var_name])) == 2) {
      # Two-valued predictor: one box per evaluated level, labelled with the
      # level's actual value rather than an assumed 0/1 ordering.
      labels <- as.character(levels_i)
      dfbin <- data.frame(
        variable = factor(rep(labels, each = nrow(draws)), levels = labels),
        value = to_response(as.vector(draws))
      )
      box <- if (ci) geom_boxplot(fill = "deepskyblue1") else geom_boxplot()
      plots[[i]] <- ggplot(dfbin, aes(x = variable, y = value)) + box +
        labs(title = var_name, y = "Response", x = "") + look
    } else {
      plot_df <- data.frame(x = levels_i,
                            med = to_response(apply(draws, 2, median)))
      if (ci) {
        plot_df$q05 <- to_response(apply(draws, 2, quantile,
                                          probs = 0.5 - ciwidth / 2))
        plot_df$q95 <- to_response(apply(draws, 2, quantile,
                                          probs = 0.5 + ciwidth / 2))
      }

      g <- ggplot(plot_df, aes(x = x, y = med)) +
        labs(title = var_name, y = "Response", x = "") + look

      if (trace) {
        trace_alpha <- if (ci) 0.05 else 0.025 * (fitobj$control@n.trees / 200)
        # Long format, one row per (draw, level): a single grouped layer
        # instead of one deprecated aes_string() layer per posterior draw.
        trace_df <- data.frame(
          x = rep(levels_i, times = nrow(draws)),
          med = to_response(as.vector(t(draws))),
          draw = rep(seq_len(nrow(draws)), each = ncol(draws))
        )
        g <- g + geom_line(data = trace_df, aes(group = draw),
                           alpha = trace_alpha)
      }

      if (ci) {
        g <- g + geom_ribbon(aes(ymin = q05, ymax = q95),
                             fill = "deepskyblue1", alpha = 0.3)
      }

      plots[[i]] <- g + geom_line(size = 1.25)
    }
  }

  if (panels) {
    patchwork::wrap_plots(plotlist = plots)
  } else {
    plots
  }
}
cjcarlson/embarcadero documentation built on Sept. 9, 2023, 10:47 p.m.