R/plot_ATE.R

Defines functions plot_SATE plot_PATE plot_ICATE plot_CATE

Documented in plot_CATE plot_ICATE plot_PATE plot_SATE

#' @title Plot the histogram or density of the Conditional Average Treatment Effect
#' @description Plot the conditional average treatment effect (CATE) of a 'bartCause' model.
#' The conditional average treatment effect is derived from taking the difference between
#' predictions for each individual under the control condition and under the treatment condition averaged over the population.
#' Means of the CATE distribution will resemble SATE and PATE but the CATE distribution accounts for more uncertainty than SATE and less uncertainty than PATE.
#'
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param type histogram or density
#' @param ci_80 TRUE/FALSE. Show the 80\% credible interval?
#' @param ci_95 TRUE/FALSE. Show the 95\% credible interval?
#' @param reference numeric. Show a vertical reference line at this value
#' @param .mean TRUE/FALSE. Show the mean reference line
#' @param .median TRUE/FALSE. Show the median reference line
#'
#'
#' @author George Perrett, Joseph Marlo
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 bartCause
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSup.rule = 'none'
#' )
#' plot_CATE(model_results)
#' }
plot_CATE <- function(.model, type = c('histogram', 'density'), ci_80 = FALSE, ci_95 = FALSE, reference = NULL, .mean = FALSE, .median = FALSE){

  # fail early on an invalid model or an unsupported plot type;
  # only the first element of `type` is used so the default vector works
  validate_model_(.model)
  type <- tolower(type[1])
  if (type %notin% c('histogram', 'density')) stop("type must be 'histogram' or 'density'")

  # set title according to the estimand stored on the model
  .title <- switch(
    .model$estimand,
    ate = "Posterior of Average Treatment Effect",
    att = "Posterior of Average Treatment Effect of the Treated",
    atc = "Posterior of Average Treatment Effect of the Control"
  )

  # calculate stats: posterior draws, bounds of the 80% and 95% intervals,
  # and a kernel density estimate used to shade the intervals on density plots
  cate <- bartCause::extract(.model, 'cate')
  cate <- as.data.frame(cate)
  ub <- quantile(cate$cate, 0.9)
  lb <- quantile(cate$cate, 0.1)
  ub.95 <- quantile(cate$cate, 0.975)
  lb.95 <- quantile(cate$cate, 0.025)
  dd <- density(cate$cate)
  dd <- with(dd, data.frame(x, y))

  # build base plot; the manual linetype scale serves the optional
  # 'mean' / 'median' reference lines added at the end
  p <- ggplot(cate, aes(cate)) +
    scale_linetype_manual(values = c(2, 3)) +
    theme(legend.title = element_blank()) +
    labs(title = .title,
         x = toupper(.model$estimand))

  # histogram
  if (type == 'histogram'){
    p <- p +
      # `color` (not `ccolor`) so the bar outline is actually applied
      geom_histogram(fill = 'grey60', color = 'black') +
      labs(y = 'Frequency')

    # add credible intervals as horizontal bars along the x-axis
    if (isTRUE(ci_95)) p <- p + geom_segment(x = lb.95, xend = ub.95, y = 0, yend = 0, size = 3, color = 'grey10')
    if (isTRUE(ci_80)) p <- p + geom_segment(x = lb, xend = ub, y = 0, yend = 0, size = 1.5, color = 'grey25')
  }

  # density
  if (type == 'density'){
    p <- p +
      geom_density() +
      labs(y = 'Density',
           linetype = NULL)

    # add credible intervals by shading the part of the density curve
    # that lies between the interval bounds
    if (isTRUE(ci_95)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb.95 & x < ub.95),
                    aes(x = x, y = y, ymax = y),
                    ymin = 0, fill = "grey40", colour = NA, alpha = 0.8)
    }
    if (isTRUE(ci_80)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb & x < ub),
                    aes(x = x, y = y, ymax = y),
                    ymin = 0, fill = "grey30", colour = NA, alpha = 0.8)
    }
  }

  # add reference lines; inside aes(), `cate` resolves to the column of the
  # `cate` data frame
  if (isTRUE(.mean)) p <- p + geom_vline(data = cate, aes(xintercept = mean(cate), linetype = 'mean'))
  if (isTRUE(.median)) p <- p + geom_vline(data = cate, aes(xintercept = median(cate), linetype = 'median'))
  if (!is.null(reference)) p <- p + geom_vline(xintercept = reference)

  return(p)
}


#' @title Plot Individual Conditional Average Treatment effects
#' @description Plots a histogram of Individual Conditional Average Treatment effects (ICATE).
#' ICATEs are the difference in each individual's predicted outcome under the treatment and predicted outcome under the control averaged over the individual.
#' Plots of ICATEs are useful to identify potential heterogeneous treatment effects between different individuals. ICATE plots can be grouped by discrete variables.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param .group_by a grouping variable as a vector
#' @param n_bins number of bins
#' @param .alpha transparency of histograms
#'
#' @author George Perrett
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr bartCause
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSup.rule = 'none'
#' )
#' plot_ICATE(model_results, lalonde$married)
#' }
plot_ICATE <- function(.model, .group_by = NULL, n_bins = 30, .alpha = .7){

  # validate inputs; is_discrete_() is presumably expected to stop on a
  # non-discrete grouping variable (its return value is not used here)
  validate_model_(.model)
  if (!is.null(.group_by)) is_discrete_(.group_by)

  # column means of the extracted 'icate' posterior give one value per column
  posterior <- bartCause::extract(.model, 'icate')
  # build the tibble explicitly: calling as_tibble() on a bare vector is
  # deprecated, and naming the column here makes the `value` aesthetic obvious
  icates <- tibble(value = unname(apply(posterior, 2, mean)))

  # adjust value based on estimand
  .group_by <- adjust_for_estimand_(.model, .group_by)

  if (is.null(.group_by)){
    # single histogram of all ICATEs
    p <- ggplot(icates, aes(x = value)) +
      geom_histogram(bins = n_bins, color = 'black')
  } else {
    # overlaid, semi-transparent histograms, one fill colour per group
    p <- ggplot(data = icates,
                aes(x = value, fill = as.factor(.group_by))) +
      geom_histogram(position = 'identity', bins = n_bins, alpha = .alpha, col = 'black')
  }

  # add labels
  p <- p +
    labs(title = NULL,
         x = NULL,
         y = 'Count',
         fill = NULL)

  return(p)
}

#' @title Plot histogram or density of Population Average Treatment Effect
#' @description Plot shows the Population Average Treatment Effect which is derived from the posterior predictive distribution of the difference between \eqn{y | z=1, X} and \eqn{y | z=0, X}.
#' Mean of PATE will resemble CATE and SATE but PATE will account for more uncertainty and is recommended for informing inferences on the average treatment effect.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param type histogram or density
#' @param ci_80 TRUE/FALSE. Show the 80\% credible interval?
#' @param ci_95 TRUE/FALSE. Show the 95\% credible interval?
#' @param reference numeric. Show a vertical reference line at this value
#' @param .mean TRUE/FALSE. Show the mean reference line
#' @param .median TRUE/FALSE. Show the median reference line
#'
#' @author George Perrett, Joseph Marlo
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 bartCause
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSup.rule = 'none'
#' )
#' plot_PATE(model_results)
#' }
plot_PATE <- function(.model, type = c('histogram', 'density'), ci_80 = FALSE, ci_95 = FALSE, reference = NULL, .mean = FALSE, .median = FALSE){
  # fail early on an invalid model or an unsupported plot type;
  # only the first element of `type` is used so the default vector works
  validate_model_(.model)
  type <- tolower(type[1])
  if (type %notin% c('histogram', 'density')) stop("type must be 'histogram' or 'density'")

  # set title (this function plots the *population* average treatment effect)
  .title <- switch(
    .model$estimand,
    ate = "Posterior of Population Average Treatment Effect",
    att = "Posterior of Population Average Treatment Effect of the Treated",
    atc = "Posterior of Population Average Treatment Effect of the Control"
  )

  # calculate stats: difference of the two extracted outcome matrices,
  # transposed so that apply(., 2, mean) below averages within each column.
  # extract() is namespace-qualified, matching plot_CATE() and plot_ICATE().
  y.1 <- bartCause::extract(.model, 'y.1')
  y.0 <- bartCause::extract(.model, 'y.0')
  pate.samples <- t(y.1 - y.0)

  # check overlap
  pate_overlap <- apply_overlap_rules(.model)

  # get different pates if an overlap rule would remove any cases
  pates <- tibble(none = apply(pate.samples, 2, mean))

  if (pate_overlap$sum_sd_removed > 0 || pate_overlap$sum_chisq_removed > 0) {
    # drop = FALSE keeps a matrix even when a single row survives the rule
    pates$sd <- apply(pate.samples[!pate_overlap$ind_sd_removed, , drop = FALSE], 2, mean)
    pates$chisq <- apply(pate.samples[!pate_overlap$ind_chisq_removed, , drop = FALSE], 2, mean)
  }

  # pivot to long form: `name` (which overlap rule) and `value`
  pates <- pivot_longer(pates, cols = seq_along(pates))

  # calculate bounds for each type of pate (no overlap rule, sd and chisq)
  ub <- tapply(pates$value, pates$name, function(i){quantile(i, .9)})
  lb <- tapply(pates$value, pates$name, function(i){quantile(i, .1)})
  lb.95 <- tapply(pates$value, pates$name, function(i){quantile(i, .025)})
  ub.95 <- tapply(pates$value, pates$name, function(i){quantile(i, .975)})

  # one row per overlap rule; used for interval bars and reference lines so
  # that each facet gets its own values
  pate_summary <- data.frame(
    name = names(lb),
    lb = as.vector(lb),
    ub = as.vector(ub),
    lb.95 = as.vector(lb.95),
    ub.95 = as.vector(ub.95),
    pate_mean = as.vector(tapply(pates$value, pates$name, mean)),
    pate_median = as.vector(tapply(pates$value, pates$name, median)),
    stringsAsFactors = FALSE
  )

  # calculate densities and use bind_rows() to roll into a single df for ggplot;
  # each row carries its rule's interval bounds for the subset() calls below
  dd <- tapply(pates$value, pates$name, density)
  dd <-
    lapply(seq_along(dd), function(i) {
      data.frame(x = dd[[i]]$x,
                 y = dd[[i]]$y,
                 name = names(dd)[i],
                 lb.95 = lb.95[[names(dd)[i]]],
                 ub.95 = ub.95[[names(dd)[i]]],
                 lb = lb[[names(dd)[i]]],
                 ub = ub[[names(dd)[i]]])
    }) %>%
    bind_rows()


  # build base plot
  p <- ggplot(pates, aes(value)) +
    scale_linetype_manual(values = c(2, 3)) +
    theme(legend.title = element_blank()) +
    labs(title = .title,
         x = toupper(.model$estimand))

  # facet if removal rules would create different results
  if(length(unique(pates$name)) > 1){
    .facet_lab <- c(
      `none` = "No overlap rule applied: 0 cases (0%) were removed",
      `sd` = paste0("Standard deviation overlap rule applied: ", pate_overlap$sum_sd_removed, ' cases (',round((pate_overlap$sum_sd_removed/nrow(pate.samples)*100), 2) ,'%) were removed'),
      `chisq` = paste0("Chi-squared overlap rule applied: ", pate_overlap$sum_chisq_removed, ' cases (',round((pate_overlap$sum_chisq_removed/nrow(pate.samples)*100), 2) ,'%) were removed')
    )
    p <- p +
      facet_wrap(~ factor(name,
                          levels = c('none', 'sd', 'chisq')),
                 ncol = 1, labeller = as_labeller(.facet_lab))
  }
  # histogram
  if (type == 'histogram'){
    p <- p +
      geom_histogram(fill = 'grey60', color = 'black') +
      labs(y = 'Frequency')

    # add credible intervals; bounds are mapped from the per-rule summary so
    # this also works when the plot is faceted by overlap rule
    if (isTRUE(ci_95)) p <- p + geom_segment(data = pate_summary, aes(x = lb.95, xend = ub.95, y = 0, yend = 0), size = 3, color = 'grey10')
    if (isTRUE(ci_80)) p <- p + geom_segment(data = pate_summary, aes(x = lb, xend = ub, y = 0, yend = 0), size = 1.5, color = 'grey25')
  }

  # density
  if (type == 'density'){
    p <- p +
      geom_density() +
      labs(y = 'Density')

    # add credible intervals; inside subset() the bound names resolve to the
    # per-rule columns of `dd`, not to the tapply results above
    if (isTRUE(ci_95)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb.95 & x < ub.95),
                    aes(x = x, y = y, ymax = y, group = name),
                    ymin = 0, fill = "grey40", colour = NA, alpha = 0.8)
    }
    if (isTRUE(ci_80)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb & x < ub),
                    aes(x = x, y = y, ymax = y, group = name),
                    ymin = 0, fill = "grey30", colour = NA, alpha = 0.8)
    }
  }

  # add reference lines (per overlap rule, from the summary table)
  if (isTRUE(.mean)) p <- p + geom_vline(data = pate_summary, aes(xintercept = pate_mean, linetype = 'mean'))
  if (isTRUE(.median)) p <- p + geom_vline(data = pate_summary, aes(xintercept = pate_median, linetype = 'median'))
  if (!is.null(reference)) p <- p + geom_vline(xintercept = reference)

  return(p)
}

#' @title Plot histogram or density of Sample Average Treatment Effects
#' @description Plot a histogram or density of the Sample Average Treatment Effect (SATE). The Sample Average Treatment Effect is derived from taking the difference of each individual's observed outcome and a predicted counterfactual outcome from a BART model averaged over the population.
#' The mean of SATE will resemble means of CATE and PATE but will account for the least uncertainty.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param type histogram or density
#' @param ci_80 TRUE/FALSE. Show the 80\% credible interval?
#' @param ci_95 TRUE/FALSE. Show the 95\% credible interval?
#' @param reference numeric. Show a vertical reference line at this x-axis value
#' @param .mean TRUE/FALSE. Show the mean reference line
#' @param .median TRUE/FALSE. Show the median reference line
#' @param check_overlap TRUE/FALSE. Check if any overlap rules are applicable
#' @param overlap_rule enter overlap rules to view how different bartCause removal rules would have influenced results. Only applicable if check_overlap is TRUE.
#'
#' @author George Perrett, Joseph Marlo
#'
#' @return ggplot object
#' @export
#' @import ggplot2 bartCause
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSup.rule = 'none'
#' )
#' plot_SATE(model_results)
#' }
plot_SATE <- function(.model, type = c('histogram', 'density'), ci_80 = FALSE, ci_95 = FALSE, reference = NULL, .mean = FALSE, .median = FALSE, check_overlap = FALSE, overlap_rule = c('none', 'sd', 'chisq')){

  validate_model_(.model)
  # only the first element of `type` is used so the default vector works
  type <- tolower(type[1])

  # without an overlap check only the rule the model was fit with is shown
  if(isFALSE(check_overlap)){
    overlap_rule <- .model$commonSup.rule
  }

  # with an overlap check, always include the rule the model was fit with.
  # This has to happen before the results are subset by `overlap_rule`,
  # otherwise the rule would get a facet label but never be plotted.
  if(isTRUE(check_overlap) && .model$commonSup.rule %notin% overlap_rule){
    overlap_rule <- c(overlap_rule, .model$commonSup.rule)
  }

  if (type %notin% c('histogram', 'density')) stop("type must be 'histogram' or 'density'")
  if (any(overlap_rule %notin% c('chisq', 'none', 'sd'))) stop("'none', 'sd' and 'chisq' are only accepted overlap rules")

  # set title
  .title <- switch(
    .model$estimand,
    ate = "Posterior of Sample Average Treatment Effect",
    att = "Posterior of Sample Average Treatment Effect of the Treated",
    atc = "Posterior of Sample Average Treatment Effect of the Control"
  )

  # maps the treatment indicator to +1 / -1 (2 * trt - 1), restricted to the
  # subset selected by the estimand
  .sign <- switch (.model$estimand,
    ate = (2 *.model$trt - 1),
    att = (2 *.model$trt[.model$trt == 1] - 1),
    atc = (2 *.model$trt[.model$trt == 0] - 1)
  )

  # calculate stats: observed response for the same subset
  y_obs <- switch (.model$estimand,
                   ate = .model$data.rsp@y,
                   att = .model$data.rsp@y[.model$trt == 1],
                   atc = .model$data.rsp@y[.model$trt == 0]
  )

  # observed minus counterfactual, sign-flipped via `.sign`;
  # extract() is namespace-qualified, matching plot_CATE() and plot_ICATE()
  y_cf <- bartCause::extract(.model, 'y.cf')
  sate.samples <- (y_obs - t(y_cf))*.sign

  # check overlap
  sate_overlap <- apply_overlap_rules(.model)

  # get different sates; drop = FALSE keeps a matrix even when a single row
  # survives a removal rule
  sates <- tibble(
    none = apply(sate.samples, 2, mean),
    sd = apply(sate.samples[!sate_overlap$ind_sd_removed, , drop = FALSE], 2, mean),
    chisq = apply(sate.samples[!sate_overlap$ind_chisq_removed, , drop = FALSE], 2, mean)
  )

  # keep only the requested rules
  sates <- sates[, overlap_rule]

  # pivot to long form: `name` (which overlap rule) and `value`
  sates <- pivot_longer(sates, cols = seq_along(sates))
  # calculate bounds for each type of sate (no overlap rule, sd and chisq)
  ub <- tapply(sates$value, sates$name, function(i){quantile(i, .9)})
  lb <- tapply(sates$value, sates$name, function(i){quantile(i, .1)})
  lb.95 <- tapply(sates$value, sates$name, function(i){quantile(i, .025)})
  ub.95 <- tapply(sates$value, sates$name, function(i){quantile(i, .975)})

  # per-rule columns for cis and summaries; these are what aes() picks up in
  # the histogram segments and the reference lines below
  sates <- sates %>%
    group_by(name) %>%
    mutate(ub = quantile(value, .9),
           lb = quantile(value, .1),
           ub.95 = quantile(value, .975),
           lb.95 = quantile(value, .025),
           sate_mean = mean(value),
           sate_median = median(value)) %>%
    ungroup()


  # calculate densities and use bind_rows() to roll into a single df for ggplot;
  # each row carries its rule's interval bounds for the subset() calls below
  dd <- tapply(sates$value, sates$name, density)
  dd <-
    lapply(seq_along(dd), function(i) {
      data.frame(x = dd[[i]]$x,
                 y = dd[[i]]$y,
                 name = names(dd)[i],
                 lb.95 = lb.95[[names(dd)[i]]],
                 ub.95 = ub.95[[names(dd)[i]]],
                 lb = lb[[names(dd)[i]]],
                 ub = ub[[names(dd)[i]]])
    }) %>%
    bind_rows()


  # build base plot
  p <- ggplot(sates, aes(value)) +
    scale_linetype_manual(values = c(2, 3)) +
    theme(legend.title = element_blank()) +
    labs(title = .title,
         x = toupper(.model$estimand))

  # apply overlap rules
  if(isTRUE(check_overlap)){
    # facet labels report how many cases each rule would remove
    .facet_lab <- character(0)

    if('none' %in% overlap_rule){
      .facet_lab <- c(.facet_lab, `none` = "No overlap rule applied: 0 cases (0%) were removed")
    }

    if('sd' %in% overlap_rule){
      .facet_lab <- c(.facet_lab, `sd` = paste0("Standard deviation overlap rule applied: ", sate_overlap$sum_sd_removed, ' cases (',round((sate_overlap$sum_sd_removed/nrow(sate.samples)*100), 2) ,'%) were removed'))
    }

    if('chisq' %in% overlap_rule){
      .facet_lab <- c(.facet_lab, `chisq` = paste0("Chi-squared overlap rule applied: ", sate_overlap$sum_chisq_removed, ' cases (',round((sate_overlap$sum_chisq_removed/nrow(sate.samples)*100), 2) ,'%) were removed'))
    }

    p <- p +
      facet_wrap(~ factor(name,
                          levels = c('none', 'sd', 'chisq')),
                 ncol = 1, labeller = as_labeller(.facet_lab))
  }

  # histogram
  if (type == 'histogram'){
    p <- p +
      geom_histogram(fill = 'grey60', color = 'black') +
      labs(y = 'Frequency')

    # add credible intervals; bounds come from the per-rule columns of `sates`
    if (isTRUE(ci_80)) p <- p + geom_segment(aes(x = lb, xend = ub, y = 0, yend = 0), size = 3, color = 'grey10')
    if (isTRUE(ci_95)) p <- p + geom_segment(aes(x = lb.95, xend = ub.95, y = 0, yend = 0), size = 1.5, color = 'grey25')
  }

  # density
  if (type == 'density'){
    p <- p +
      geom_density() +
      labs(y = 'Density')

    # add credible intervals; inside subset() the bound names resolve to the
    # per-rule columns of `dd`, not to the tapply results above
    if (isTRUE(ci_95)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb.95 & x < ub.95),
                    aes(x = x, y = y, ymax = y, group = name),
                    ymin = 0, fill = "grey40", colour = NA, alpha = 0.8)
    }
    if (isTRUE(ci_80)){
      p <- p +
        geom_ribbon(data = subset(dd, x > lb & x < ub),
                    aes(x = x, y = y, ymax = y, group = name),
                    ymin = 0, fill = "grey30", colour = NA, alpha = 0.8)
    }
  }
  # add reference lines
  if (isTRUE(.mean)) p <- p + geom_vline(data = sates, aes(xintercept = sate_mean, linetype = 'mean'))
  if (isTRUE(.median)) p <- p + geom_vline(data = sates, aes(xintercept = sate_median, linetype = 'median'))
  if (!is.null(reference)) p <- p + geom_vline(xintercept = reference)

  return(p)
}
priism-center/plotBart documentation built on June 2, 2024, 8:50 a.m.