R/plot_balance.R

Defines functions plot_covariance print_covariance print_balance plot_balance

Documented in plot_balance plot_covariance print_balance print_covariance

#' @title Plot the balance
#' @description Visualize balance of variables between treatment and control
#'   groups. Balance plot reflects balance in standardized units.
#'
#' @param .data dataframe
#' @param treatment the column denoted treatment. Must be binary.
#' @param confounders character list of column names denoting the X columns of
#'   interest. Columns must be numeric or logical (not factor or character).
#' @param compare character of either means, variance or covariance; denotes
#'   what to compare balance on
#' @param estimand character of either ATE, ATT or ATC the causal estimand you
#'   are making inferences about
#' @param limit_continuous whole number that can be used to limit the plot to
#'   only show the \code{limit_continuous} most imbalanced continuous variables
#' @param limit_catagorical whole number that can be used to limit the plot to
#'   only show the \code{limit_catagorical} most imbalanced binary/categorical
#'   variables
#'
#' @details Variables whose values are exactly 0 and 1 (with no missing values)
#'   are treated as binary/categorical and shown as raw differences in
#'   proportions; all other variables are treated as continuous and shown as
#'   standardized mean differences. For \code{compare = 'variance'} only
#'   continuous variables are plotted, and the value shown is
#'   \code{sqrt(var_treated / var_control)}. Displayed values are clipped to
#'   [-2, 2] (means) or [0.25, 4] (variance); continuous variables beyond those
#'   limits are drawn in red.
#'
#' @author George Perrett & Joseph Marlo
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr patchwork
#' @importFrom tidyr pivot_longer pivot_wider
#' @importFrom utils combn
#' @importFrom stats reorder var na.omit
#' @examples
#' data(lalonde)
#' plot_balance(lalonde, 'treat', c('re78', 'age', 'educ'),
#' compare = 'means', estimand = 'ATE') +
#' labs(title = 'My new title')
#'
plot_balance <- function(.data, treatment, confounders, compare = c('means', 'variance', 'covariance'), estimand = c('ATE', 'ATT', 'ATC'), limit_continuous = NULL, limit_catagorical = NULL){
  if (missing(treatment)) stop('enter a string indicating the name of the treatment variable')

  # drop = FALSE keeps a data frame even for a single confounder, so we
  # inspect column classes rather than the classes of individual values
  confounder_classes <- unlist(lapply(.data[, confounders, drop = FALSE], class))
  if ('factor' %in% confounder_classes) stop('factor variables must be converted to numeric or logical indicator variables')
  if ('character' %in% confounder_classes) stop('character variables must be converted to numeric or logical indicator variables')

  # limits are optional; when supplied they must be a single whole number
  is_whole_number <- function(x) is.numeric(x) && length(x) == 1 && !is.na(x) && x == round(x)
  if (!is.null(limit_continuous) && !is_whole_number(limit_continuous)) stop('limit_continuous must be a whole number that can be converted to numeric')
  if (!is.null(limit_catagorical) && !is_whole_number(limit_catagorical)) stop('limit_catagorical must be a whole number that can be converted to numeric')

  if (length(table(.data[[treatment]])) != 2) stop("treatment must be binary")
  .data[[treatment]] <- coerce_to_logical_(.data[[treatment]])

  # match.arg() already rejects anything outside the allowed choices
  compare <- match.arg(compare)
  estimand <- match.arg(estimand)
  x_var <- if (compare == 'variance') 'variance' else 'means'

  # ---- prep the data: covariate columns plus a column named 'treatment' ----
  if (compare == 'covariance') {
    if (length(confounders) < 2) stop('at least two confounders are required to compute covariance balance')
    confounder_data <- .data[, confounders, drop = FALSE]
    pairs <- combn(names(confounder_data), 2)
    # product of every pair of confounders
    cov_dat <- combn(confounder_data, 2, FUN = Reduce, f = `*`)
    # name each interaction after its two source columns
    colnames(cov_dat) <- paste(pairs[1, ], pairs[2, ], sep = '*')
    .data <- cbind.data.frame(cov_dat, treatment = .data[[treatment]])
  } else {
    .data <- .data %>%
      dplyr::select(all_of(c(confounders, treatment))) %>%
      rename(treatment = all_of(treatment))
  }

  # ---- classify covariates ----
  # Each column is checked on its own so that its type does not depend on the
  # types of the other columns (apply() on a data frame coerces them jointly).
  # Logical columns count as 0/1; any NA makes a column 'continuous'.
  is_binary <- function(x) {
    if (!is.numeric(x) && !is.logical(x)) return(FALSE)
    u <- unique(as.numeric(x))
    identical(u[order(u)], c(0, 1))
  }
  covariate_names <- setdiff(names(.data), 'treatment')
  data_types <- data.frame(
    name = covariate_names,
    type = ifelse(
      vapply(.data[covariate_names], is_binary, logical(1), USE.NAMES = FALSE),
      'binary/categorical',
      'continuous'
    ),
    stringsAsFactors = FALSE
  )

  # ---- per-group summaries, one row per covariate, type joined by name ----
  .data <- .data %>%
    pivot_longer(cols = -all_of('treatment')) %>%
    group_by(across(all_of(c('name', 'treatment')))) %>%
    summarize(
      variance = var(value, na.rm = TRUE),
      mean = mean(value, na.rm = TRUE),
      .groups = 'drop') %>%
    pivot_wider(names_from = all_of('treatment'), values_from = c(variance, mean)) %>%
    arrange(name) %>%
    left_join(data_types, by = 'name')

  .data <- .data %>%
    mutate(
      # the estimand decides whose spread the difference is scaled by
      means = (mean_TRUE - mean_FALSE) / switch(estimand,
        ATE = sqrt((variance_TRUE + variance_FALSE) / 2),
        ATT = sqrt(variance_TRUE),
        ATC = sqrt(variance_FALSE)
      ),
      # NOTE(review): this is a ratio of standard deviations (square root of
      # the variance ratio) although the axis label says "Ratio of variance"
      variance = sqrt(variance_TRUE / variance_FALSE),
      # binary covariates are reported as raw differences in proportions
      means = if_else(type != 'continuous', mean_TRUE - mean_FALSE, means)
    )

  # ranking key is computed before clipping so clipped values stay ordered
  .data <- .data %>%
    mutate(order = if (compare == 'variance') if_else(variance < 1, 1 / variance, variance) else abs(means))

  # flag values beyond the display limits, then clip them to those limits
  .data <- .data %>%
    mutate(
      flag_means = if_else(means > 2 | means < -2, 1, 0),
      flag_variance = if_else(variance > 4 | variance < .25, 1, 0),
      flag = if (compare == 'variance') flag_variance else flag_means,
      means = pmin(pmax(means, -2), 2),
      variance = pmin(pmax(variance, .25), 4)
    ) %>%
    na.omit()

  # ---- optional limits: keep the N most imbalanced of each kind ----
  if (!is.null(limit_catagorical)) {
    .data <- .data %>%
      arrange(desc(order)) %>%
      group_by(type) %>%
      mutate(rank = row_number()) %>%
      filter(rank <= limit_catagorical | type == 'continuous') %>%
      ungroup()
  }

  if (!is.null(limit_continuous)) {
    .data <- .data %>%
      arrange(desc(order)) %>%
      group_by(type) %>%
      mutate(rank = row_number()) %>%
      filter(rank <= limit_continuous | type != 'continuous') %>%
      ungroup()
  }

  if (nrow(.data) == 0) stop('no variables with computable balance statistics are left to plot')

  # ---- plots ----
  reference_line <- if (x_var == 'means') 0 else 1
  x_title_binary <- switch(compare,
    means = 'Mean difference (Proportion)',
    variance = 'Ratio of variance (log scale)',
    covariance = 'Scaled mean difference of interactions (balance of covariance)'
  )
  x_title_continuous <- switch(compare,
    means = 'Standardized mean difference',
    variance = 'Ratio of variance (log scale)',
    covariance = 'Standardized mean difference of interactions (balance of covariance)'
  )
  # only continuous points are colored by flag, so only they drive the subtitle
  any_flagged <- any(.data$flag[.data$type == 'continuous'] != 0)

  p1 <- .data %>%
    filter(type != 'continuous') %>%
    ggplot(aes(
      x = .data[[x_var]],
      y = reorder(name, order))) +
    geom_vline(
      xintercept = reference_line,
      linetype = 'dashed',
      color = 'gray60'
    ) +
    geom_point(size = 4) +
    labs(
      x = x_title_binary,
      y = NULL,
      color = NULL
    ) +
    coord_cartesian(xlim = c(-1, 1)) +
    theme(legend.position = 'none')

  p2 <- .data %>%
    filter(type == 'continuous') %>%
    ggplot(aes(
      x = .data[[x_var]],
      y = reorder(name, order),
      col = as.factor(flag)
    )) +
    geom_vline(
      xintercept = reference_line,
      linetype = 'dashed',
      color = 'gray60'
    ) +
    geom_point(size = 4) +
    # named values so flagged points are red even when every point is flagged
    scale_color_manual(values = c('0' = 'black', '1' = 'red')) +
    labs(
      x = x_title_continuous,
      y = NULL,
      color = NULL
    ) +
    theme(legend.position = 'none')

  if (compare == 'variance') {
    p <- p2 +
      coord_cartesian(xlim = c(.25, 4)) +
      scale_x_continuous(trans = 'log10') +
      labs(
        title = 'Balance',
        subtitle = if (any_flagged) {
          'points represent the treatment group\nred points are beyond 4 times different'
        } else {
          'points represent the treatment group'
        }
      )
  } else {
    sd_subtitle <- if (any_flagged) {
      'points represent the treatment group\nred points are beyond 2 standard deviations.'
    } else {
      'points represent the treatment group'
    }
    p2 <- p2 + coord_cartesian(xlim = c(-2, 2))
    types_present <- unique(.data$type)

    if (length(types_present) > 1) {
      p1 <- p1 + facet_wrap(~type) + labs(title = 'Balance', subtitle = sd_subtitle)
      p2 <- p2 + facet_wrap(~type)
      p <- p1 + p2
    } else if (types_present == 'continuous') {
      p <- p2 + labs(title = 'Balance', subtitle = sd_subtitle)
    } else {
      p <- p1 + labs(
        title = 'Balance',
        subtitle = 'points represent the treatment group')
    }
  }

  p
}

#' @title Print balance statistics
#' @description See balance statistics of variables between treatment and control groups.
#'
#' @param .data dataframe
#' @param treatment the column denoted treatment. Must be binary.
#' @param confounders character list of column names denoting the X columns of
#'   interest. Columns must be numeric or logical (not factor or character).
#' @param estimand character of either ATE, ATT or ATC the causal estimand you are making inferences about
#' @author George Perrett
#'
#' @return tibble with columns \code{variable}, \code{difference in means},
#'   \code{standardized difference in means} and \code{ratio of the variance}.
#'   The last two are character columns holding values rounded to 2 decimals,
#'   or \code{"--"} for binary (0/1) variables. The ratio is computed as
#'   \code{sqrt(var_treated / var_control)}.
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom tidyr pivot_longer pivot_wider
#' @importFrom stats var na.omit
#'
#' @examples
#' data(lalonde)
#' print_balance(lalonde, 'treat', confounders = c('re78', 'age', 'educ'), estimand = 'ATE')
print_balance <- function(.data, treatment, confounders, estimand = c('ATE', 'ATT', 'ATC')){
  if (missing(treatment)) stop('enter a string indicating the name of the treatment variable')

  # drop = FALSE keeps a data frame even for a single confounder
  confounder_classes <- unlist(lapply(.data[, confounders, drop = FALSE], class))
  if ('factor' %in% confounder_classes) stop('factor variables must be converted to numeric or logical indicator variables')
  if ('character' %in% confounder_classes) stop('character variables must be converted to numeric or logical indicator variables')

  if (length(table(.data[[treatment]])) != 2) stop("treatment must be binary")
  .data[[treatment]] <- coerce_to_logical_(.data[[treatment]])
  # match.arg() already rejects anything outside the allowed choices
  estimand <- match.arg(estimand)

  # Classify each confounder on its own (apply() on a data frame coerces all
  # columns jointly, so an integer 0/1 column could be misread as continuous).
  # Logical columns count as 0/1; any NA makes a column 'continuous'.
  is_binary <- function(x) {
    if (!is.numeric(x) && !is.logical(x)) return(FALSE)
    u <- unique(as.numeric(x))
    identical(u[order(u)], c(0, 1))
  }
  confounder_names <- unique(confounders)
  data_types <- data.frame(
    name = confounder_names,
    type = ifelse(
      vapply(.data[confounder_names], is_binary, logical(1), USE.NAMES = FALSE),
      'binary/categorical',
      'continuous'
    ),
    stringsAsFactors = FALSE
  )

  balance <- .data %>%
    dplyr::select(all_of(c(confounders, treatment))) %>%
    pivot_longer(cols = -all_of(treatment)) %>%
    group_by(across(all_of(c('name', treatment)))) %>%
    summarize(mean = mean(value, na.rm = TRUE),
              # na.rm matches mean() above and plot_balance()
              variance = var(value, na.rm = TRUE),
              .groups = 'drop') %>%
    tidyr::pivot_wider(names_from = all_of(treatment), values_from = c(variance, mean)) %>%
    dplyr::mutate(
      raw_means = mean_TRUE - mean_FALSE,
      # the estimand decides whose spread the difference is scaled by
      means = raw_means / switch(estimand,
        ATE = sqrt((variance_TRUE + variance_FALSE) / 2),
        ATT = sqrt(variance_TRUE),
        ATC = sqrt(variance_FALSE)
      ),
      variance = sqrt(variance_TRUE / variance_FALSE)
    ) %>%
    na.omit() %>%
    # join by name so the type stays aligned even if na.omit() dropped rows
    dplyr::left_join(data_types, by = 'name') %>%
    dplyr::mutate(across(c(raw_means, means, variance), function(x) round(x, 2))) %>%
    # standardized statistics are not meaningful for binary variables
    dplyr::mutate(
      `standardized difference in means` = if_else(type == 'continuous', as.character(means), '--'),
      `ratio of the variance` = if_else(type == 'continuous', as.character(variance), '--')
    ) %>%
    dplyr::select(
      variable = name,
      `difference in means` = raw_means,
      `standardized difference in means`,
      `ratio of the variance`
    ) %>%
    dplyr::arrange(variable)

  balance
}


#' @title Print covariance statistics
#' @description See balance statistics of covariance for specified variables between treatment and control groups.
#'   Every pairwise product of the confounders is formed and its standardized
#'   difference in means between groups is reported.
#'
#' @param .data dataframe
#' @param treatment the column denoted treatment. Must be binary.
#' @param confounders character list of at least two column names denoting the
#'   X columns of interest. Columns must be numeric or logical.
#' @param estimand character of either ATE, ATT or ATC the causal estimand you are making inferences about
#' @author George Perrett
#'
#' @return tibble with columns \code{variable} (named \code{"a*b"} for the
#'   interaction of \code{a} and \code{b}) and
#'   \code{standardized difference in means}, rounded to 2 decimals and sorted
#'   from most to least imbalanced
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom tidyr pivot_longer pivot_wider
#' @importFrom stats var na.omit
#' @importFrom utils combn
#'
#' @examples
#' data(lalonde)
#' print_covariance(lalonde, 'treat', confounders = c('re78', 'age', 'educ'), estimand = 'ATE')
print_covariance <- function(.data, treatment, confounders, estimand = c('ATE', 'ATT', 'ATC')){
  if (missing(treatment)) stop('enter a string indicating the name of the treatment variable')

  confounder_classes <- unlist(lapply(.data[, confounders, drop = FALSE], class))
  if ('factor' %in% confounder_classes) stop('factor variables must be converted to numeric or logical indicator variables')
  if ('character' %in% confounder_classes) stop('character variables must be converted to numeric or logical indicator variables')
  # combn() needs at least two columns to form a pair
  if (length(confounders) < 2) stop('at least two confounders are required to compute covariance balance')

  if (length(table(.data[[treatment]])) != 2) stop("treatment must be binary")
  .data[[treatment]] <- coerce_to_logical_(.data[[treatment]])
  # match.arg() already rejects anything outside the allowed choices
  estimand <- match.arg(estimand)

  confounder_data <- .data[, confounders, drop = FALSE]
  pairs <- combn(names(confounder_data), 2)
  # product of every pair of confounders, named after its two source columns
  cov_dat <- combn(confounder_data, 2, FUN = Reduce, f = `*`)
  colnames(cov_dat) <- paste(pairs[1, ], pairs[2, ], sep = '*')
  cov_dat <- cbind.data.frame(cov_dat, treatment = .data[[treatment]])

  cov_dat %>%
    pivot_longer(cols = -all_of('treatment')) %>%
    group_by(across(all_of(c('name', 'treatment')))) %>%
    summarize(mean = mean(value, na.rm = TRUE),
              # na.rm matches mean() above and plot_balance(compare = 'covariance')
              variance = var(value, na.rm = TRUE),
              .groups = 'drop') %>%
    pivot_wider(names_from = all_of('treatment'), values_from = c(variance, mean)) %>%
    mutate(
      # the estimand decides whose spread the difference is scaled by
      means = (mean_TRUE - mean_FALSE) / switch(estimand,
        ATE = sqrt((variance_TRUE + variance_FALSE) / 2),
        ATT = sqrt(variance_TRUE),
        ATC = sqrt(variance_FALSE)
      ),
      # not reported, but kept so na.omit() below drops the same
      # interactions (undefined variance ratio) as before
      variance = sqrt(variance_TRUE / variance_FALSE)
    ) %>%
    na.omit() %>%
    dplyr::select(variable = name, `standardized difference in means` = means) %>%
    arrange(desc(abs(`standardized difference in means`))) %>%
    mutate(across(where(is.numeric), function(x) round(x, 2)))
}


#' @title Plot the covariance
#' @description Visualize balance of the covariance of variables between
#'   treatment and control groups with a pairs plot of the confounders, colored
#'   by treatment group. Variables are plotted on their original scale.
#'
#' @param .data dataframe
#' @param treatment the column denoted treatment. Must be binary.
#' @param confounders character list of column names denoting the X columns of interest
#' @author George Perrett
#'
#' @return plot object returned by \code{GGally::ggpairs}
#' @export
#'
#' @import ggplot2 dplyr
#' @examples
#' data(lalonde)
#' plot_covariance(lalonde, 'treat', c('re75','re74' , 'age', 'educ')) + labs(title = 'My new title')
plot_covariance <- function(.data, treatment, confounders){
  if (missing(treatment)) stop('enter a string indicating the name of the treatment variable')
  # the two-color scales below only make sense for exactly two groups
  if (length(table(.data[[treatment]])) != 2) stop("treatment must be binary")

  # copy into a fixed column name so aes() below can refer to it directly
  .data$treatment <- .data[[treatment]]
  .data %>%
    GGally::ggpairs(
      upper = list(continuous = "density", combo = "box_no_facet"),
      lower = list(continuous = "points",  combo = "box_no_facet"),
      columns = confounders,
      aes(colour = as.factor(treatment), alpha = .7)) +
    scale_color_manual(values = c('blue', 'red')) +
    scale_fill_manual(values = c('blue', 'red'))
}
priism-center/plotBart documentation built on June 2, 2024, 8:50 a.m.