R/class_funs.R

Defines functions plot.ame print.summary.ame summary.ame print.ame get_average_effects n_covs_matched_on

Documented in plot.ame print.ame print.summary.ame summary.ame

# Count the covariates on which every member of `unit`'s matched group agrees.
#
# x        An `ame` object; reads x$MGs (per-unit matched-group row indices
#          into x$data) and x$data.
# unit     Row index of the unit whose matched group is examined.
# cov_inds Column selector (integer or logical) for the covariate columns.
#
# Returns 0 for a unit without a matched group. A covariate counts as matched
# on when the squared deviations from the group's first row sum to zero.
n_covs_matched_on <- function(x, unit, cov_inds) {
  mg <- x$MGs[[unit]]
  if (is.null(mg)) {
    return(0)
  }

  # drop = FALSE keeps a data frame even when a single covariate is selected,
  # so data.matrix() maps factors to integer codes instead of falling back to
  # as.matrix(), which would yield a character matrix sweep() cannot handle.
  covs <- data.matrix(x$data[mg, cov_inds, drop = FALSE])

  # Vectorized over columns (faster than a for loop for moderate p).
  # na.rm takes care of `missing_data = 'keep'`: columns containing NA give NA.
  sum(colSums(sweep(covs, 2, covs[1, ]) ^ 2) == 0, na.rm = TRUE)
}

# Estimate the ATE, ATT and ATC of an `ame` object and their variances.
#
# Point estimates use matching weights K: K[i] sums, over every
# opposite-treatment unit j whose matched group contains i, the reciprocal of
# the number of members of that group sharing i's treatment. Variances follow
# Abadie, Drukker, Herr & Imbens (2004) assuming constant treatment effect and
# homoscedasticity.
#
# x  An `ame` object; reads x$data (treatment, outcome and `matched` columns),
#    x$MGs (per-unit matched-group row indices into x$data) and
#    x$info$treatment / x$info$outcome. Treatment is assumed coded 0/1.
#
# Returns a 3 x 2 matrix with rows 'All', 'Treated', 'Control' and columns
# 'Mean', 'Variance'.
get_average_effects <- function(x) {

  # Restrict to matched units and re-index the groups into that subset.
  unmatched <- !x$data$matched
  new_ind <- seq_along(unmatched) - cumsum(unmatched)

  x$MGs <- x$MGs[!unmatched]
  x$MGs <- lapply(x$MGs, function(z) new_ind[z])

  n <- length(x$MGs)
  Tr <- x$data[[x$info$treatment]][!unmatched]
  Y <- x$data[[x$info$outcome]][!unmatched]

  # Visit each unit j's group once and credit every opposite-treatment member
  # with 1 / (number of members sharing that member's treatment). Walking j in
  # increasing order keeps the summation order of a pairwise (i, j) loop, but
  # costs O(total group size) instead of O(n^2) membership tests.
  K <- numeric(n)
  for (j in seq_len(n)) {
    MG <- x$MGs[[j]]
    opp <- MG[Tr[MG] != Tr[j]]
    if (length(opp) == 0) {
      next
    }
    # With binary treatment every member of `opp` has treatment 1 - Tr[j].
    n_opp <- sum(Tr[MG] != Tr[j])
    K[opp] <- K[opp] + 1 / n_opp
  }

  n1 <- sum(Tr == 1)
  n0 <- sum(Tr == 0)

  ATE <- sum((2 * Tr - 1) * (1 + K) * Y) / n
  ATT <- sum((Tr - (1 - Tr) * K) * Y) / n1
  ATC <- sum((Tr * K - (1 - Tr)) * Y) / n0

  # Pooled conditional variances; each is centred on its own estimand.
  cond_var <- 0
  cond_var_t <- 0
  cond_var_c <- 0

  for (i in seq_len(n)) {
    MG <- x$MGs[[i]]
    MG <- MG[Tr[MG] != Tr[i]]

    if (Tr[i] == 1) {
      diffs <- Y[i] - Y[MG]
      cond_var <- cond_var + mean((diffs - ATE) ^ 2)
      cond_var_t <- cond_var_t + mean((diffs - ATT) ^ 2)
    } else {
      diffs <- Y[MG] - Y[i]
      cond_var <- cond_var + mean((diffs - ATE) ^ 2)
      cond_var_c <- cond_var_c + mean((diffs - ATC) ^ 2)
    }
  }

  cond_var <- cond_var / (2 * n)
  cond_var_t <- cond_var_t / (2 * n1)
  cond_var_c <- cond_var_c / (2 * n0)

  V_sample_sate <- sum(cond_var * (1 + K) ^ 2) / n ^ 2
  V_sample_satt <- sum(cond_var_t * (Tr - (1 - Tr) * K) ^ 2) / n1 ^ 2
  V_sample_satc <- sum(cond_var_c * (Tr * K - (1 - Tr)) ^ 2) / n0 ^ 2

  matrix(c(ATE, ATT, ATC, V_sample_sate, V_sample_satt, V_sample_satc),
         ncol = 2,
         dimnames = list(c('All', 'Treated', 'Control'),
                         c('Mean', 'Variance')))
}

#' @param x An object of class \code{ame}, returned by a call to
#'   \code{\link{FLAME}} or \code{\link{DAME}}.
#' @param digits Number of significant digits for printing the average treatment
#' effect.
#' @param linewidth Maximum number of characters on line; output will be wrapped
#' accordingly.
#' @param ... Additional arguments to be passed to other methods.
#' @rdname AME
#' @export
print.ame <- function(x, digits = getOption("digits"), linewidth = 80, ...) {
  # Prints how many iterations ran and units were matched, the estimated ATE
  # (continuous outcomes only), and how missing values in the matching and
  # holdout data were handled. Returns `x` invisibly.

  df <- x$data

  algo <- x$info$algo
  outcome_type <- x$info$outcome_type
  replacement <- x$info$replacement

  n_iters <- length(x$cov_sets)
  n_matched <- sum(df$matched)
  n_total <- nrow(df)

  indentation <- 2

  # Print `text` wrapped to `linewidth`, every line indented by `indentation`.
  cat_wrapped <- function(text) {
    cat(strwrap(text, width = linewidth, indent = indentation,
                exdent = indentation),
        sep = '\n ')
  }

  cat('An object of class `ame`:\n')
  cat_wrapped(paste(' ', algo, 'ran for', n_iters, 'iterations, matching',
                    n_matched, 'out of', n_total, 'units',
                    if (isTRUE(replacement)) 'with' else 'without',
                    'replacement.'))

  if (outcome_type == 'continuous') {
    # Average the unit-level CATEs when they were estimated; otherwise fall
    # back to the matching estimator.
    ATE <- if (isTRUE(x$info$estimate_CATEs)) {
      mean(x$data$CATE, na.rm = TRUE)
    } else {
      get_average_effects(x)['All', 'Mean']
    }
    # `digits` is documented as significant digits, so use signif() rather
    # than round() (which counts decimal places).
    cat_wrapped(paste0('  The average treatment effect of treatment `',
                       x$info$treatment, '` on outcome `', x$info$outcome,
                       '` is estimated to be ', signif(ATE, digits = digits),
                       '.'))
  }

  # 'none' -- or any value without a message -- prints nothing instead of
  # leaving the message variable undefined.
  missing_data_message <- switch(
    x$info$missing_data,
    drop = 'Units with missingness in the matching data were not matched.',
    impute = 'Missing values in the matching data were imputed by MICE.',
    keep = 'Missing values in the matching data were not matched on.',
    NULL
  )
  missing_holdout_message <- switch(
    x$info$missing_holdout,
    drop =
      'Units with missingness in the holdout data were not used to compute PE.',
    impute = 'Missing values in the holdout data were imputed by MICE.',
    NULL
  )

  for (msg in c(missing_data_message, missing_holdout_message)) {
    cat_wrapped(msg)
  }

  invisible(x)
}

#' Summarize the output of FLAME or DAME
#'
#' These methods create and print objects of class \code{summary.ame} containing
#' information on the numbers of units matched by the AME algorithm, matched
#' groups formed, and, if applicable, average treatment effects.
#'
#' The average treatment effect (ATE) is estimated as the average CATE estimate
#' across all matched units in the data, while the average treatment effect on
#' the treated (ATT) and average treatment effect on controls (ATC) average only
#' across matched treated or matched control units, respectively. Variances of
#' these estimates are computed as in Abadie, Drukker, Herr, and Imbens (The
#' Stata Journal, 2004) assuming constant treatment effect and homoscedasticity.
#' Note that the implemented estimator is \strong{not} asymptotically normal
#' and so in particular, asymptotically valid confidence intervals or hypothesis
#' tests cannot be conducted on its basis. In the future, the estimation
#' procedure will be changed to employ the nonparametric regression bias
#' adjustment estimator of Abadie and Imbens 2011 which is asymptotically
#' normal.
#'
#' @return A list of class \code{summary.ame} with the following entries:
#' \describe{
#' \item{MG}{
#'   A list with the number and median size of matched groups formed.
#'   Additionally, two of the highest quality matched groups formed. Quality
#'   is determined first by number of covariates matched on and second by
#'   matched group size.
#'  }
#' \item{n_matches}{
#'   A matrix detailing the number of treated and control units matched.
#'  }
#' \item{TEs}{
#'   If the matching data had a continuous outcome, estimates of the ATE, ATT,
#'   and ATC and the corresponding variance of the estimates.
#'  }
#' }

#' @name summary.ame
NULL
#> NULL

#' @param object An object of class \code{ame}, returned by a call to
#'   \code{\link{FLAME}} or \code{\link{DAME}}
#' @param ... Additional arguments to be passed on to other methods
#' @rdname summary.ame
#' @export
summary.ame <- function(object, ...) {
  df <- object$data

  outcome_name <- object$info$outcome
  treated_name <- object$info$treatment

  # Covariate columns: everything except outcome, treatment and the
  # bookkeeping columns added by the matching procedure.
  cov_inds <-
    which(!(colnames(df) %in%
              c(outcome_name, treated_name,
                'matched', 'weight', 'MG', 'CATE')))

  # Number of covariates each unit's matched group agrees on (0 if unmatched).
  n_matched_on <-
    vapply(seq_len(nrow(df)),
           function(z) n_covs_matched_on(object, z, cov_inds),
           numeric(1))

  # Control / treated counts overall, matched and unmatched. Counted directly
  # rather than via table(), which drops a row or column (yielding NA counts)
  # when every unit shares one treatment value or one matched status.
  treated <- df[[treated_name]]
  matched <- df$matched
  n_match_mat <-
    matrix(c(sum(treated == 0), sum(treated == 1),
             sum(treated == 0 & matched), sum(treated == 1 & matched),
             sum(treated == 0 & !matched), sum(treated == 1 & !matched)),
           byrow = TRUE,
           ncol = 2,
           dimnames = list(c('All', 'Matched', 'Unmatched'),
                           c('Control', 'Treated')))

  # Highest quality groups: matched on the most covariates, ties broken by
  # group size (at most two reported). Each group is identified by the first
  # unit of the group.
  max_matched_on <- max(n_matched_on)
  MG_matched_on_max <- object$MGs[which(n_matched_on == max_matched_on)]
  MG_matched_on_max <- MG_matched_on_max[!duplicated(MG_matched_on_max)]
  if (length(MG_matched_on_max) == 1) {
    best_units <- MG_matched_on_max[[1]][1]
  } else {
    first_units <-
      vapply(MG_matched_on_max, `[`, FUN.VALUE = numeric(1), 1)
    MG_matched_on_max_sizes <- vapply(MG_matched_on_max, length, numeric(1))
    sorted_quality <-
      sort(MG_matched_on_max_sizes, decreasing = TRUE, index.return = TRUE)
    best_units <- first_units[sorted_quality$ix[1:2]]
  }
  # Report units by row name in both branches, so the single-group case is
  # consistent with the multi-group case.
  highest_quality <- rownames(df)[best_units]

  unique_MGs <- object$MGs[!duplicated(object$MGs)]
  MG_size <- vapply(unique_MGs, length, FUN.VALUE = numeric(1))
  MG_size <- MG_size[MG_size > 0]

  summary_obj <- list(MG = list(`number` = length(MG_size),
                                `median_size` = median(MG_size),
                                `highest_quality` = highest_quality),
                      n_matches = n_match_mat)
  if (object$info$outcome_type == 'continuous') {
    summary_obj$TEs <- get_average_effects(object)
  }

  class(summary_obj) <- 'summary.ame'
  summary_obj
}

#' Print a summary of FLAME or DAME
#'
#' @param x An object of class \code{summary.ame}, returned by a call to
#'   \code{\link{summary.ame}}
#' @param digits Number of significant digits for printing the average treatment
#' effect estimates and their variances.
#' @param ... Additional arguments to be passed on to other methods.
#' @rdname summary.ame
#' @export
print.summary.ame <- function(x, digits = 3, ...) {
  # Table layout: a fixed-width label column followed by two right-aligned
  # value columns. The value columns are at least as wide as their longest
  # header ('Control' / 'Variance') and widen to fit any formatted treatment
  # effect estimate.
  lablen <- 13
  has_TEs <- 'TEs' %in% names(x)
  effect_rows <- c('All', 'Treated', 'Control')

  left_width <- 7
  right_width <- 8
  if (has_TEs) {
    formatted_nchar <- function(col) {
      nchar(vapply(effect_rows,
                   function(r) format(x$TEs[r, col], digits = digits,
                                      justify = 'right'),
                   character(1)))
    }
    left_width <- max(left_width, formatted_nchar(1))
    right_width <- max(right_width, formatted_nchar(2))
  }
  total_line_len <- sum(left_width, right_width, 2, lablen)

  # The three cells of one table row; extra arguments go to format() for the
  # two value cells.
  row_cells <- function(label, left, right, ...) {
    c(format(label, width = lablen),
      format(left, width = left_width, justify = 'right', ...),
      format(right, width = right_width, justify = 'right', ...))
  }

  cat('Number of Units:\n')
  cat(row_cells('', 'Control', 'Treated'))
  cat('\n')
  for (status in c('All', 'Matched', 'Unmatched')) {
    cat(row_cells(paste0('  ', status),
                  x$n_matches[status, 'Control'],
                  x$n_matches[status, 'Treated']),
        '\n')
  }

  if (has_TEs) {
    cat('\nAverage Treatment Effects:\n')
    cat(row_cells('', 'Mean', 'Variance'), '\n')
    for (effect in effect_rows) {
      cat(row_cells(paste0('  ', effect),
                    x$TEs[effect, 1], x$TEs[effect, 2],
                    digits = digits),
          '\n')
    }
  }

  value_width <- total_line_len - lablen
  cat('\nMatched Groups:\n')
  cat(format('  Number', width = lablen),
      format(x$MG$number, width = value_width, justify = 'right'),
      '\n')
  cat(format('  Median size', width = lablen),
      format(x$MG$median_size, width = value_width, justify = 'right'),
      '\n')

  # 18 == nchar('  Highest quality:'); several units are joined with ' and '.
  quality_label <- paste(x$MG$highest_quality, collapse = ' and ')
  cat('  Highest quality:',
      format(quality_label, width = total_line_len - 18, justify = 'right'))
  cat('\n')

  invisible(x)
}

#' Plot a summary of FLAME or DAME
#'
#' Plot information about numbers of covariates matched on, CATE estimates, and
#' covariate set dropping order after a call to \code{FLAME} or \code{DAME}.
#'
#' \code{plot.ame} displays four plots by default. The first contains
#' information on the number of covariates that matched groups were formed on,
#' and thereby gives some indication of the quality of matched groups across the
#' matched data. The second plots matched group sizes against CATEs, which can
#' be useful for determining whether higher quality matched groups yield
#' different treatment effect estimates than lower quality ones. The third plots
#' a density estimate of the estimated CATE distribution. The fourth displays a
#' heatmap showing which covariates were dropped (shown in black) throughout the
#' matching procedure. Requested plots are drawn in ascending order, each at
#' most once.
#'
#' @param x An object of class \code{ame}, returned by a call to
#'   \code{\link{FLAME}} or \code{\link{DAME}}.
#' @param which_plots A vector describing which plots should be displayed; any
#' subset of 1 through 4. See details.
#' @param ... Additional arguments to be passed on to other methods.
#' @return \code{x}, invisibly.
#' @export
plot.ame <- function(x, which_plots = c(1, 2, 3, 4), ...) {

  if (!all(which_plots %in% 1:4)) {
    stop('Please supply an integer 1 through 4 for `which_plots`.')
  }
  which_plots <- sort(unique(as.numeric(which_plots)))

  df <- x$data
  matched <- df$matched

  outcome_name <- x$info$outcome
  treated_name <- x$info$treatment

  Y <- df[[outcome_name]]
  Tr <- df[[treated_name]]

  cov_inds <-
    !(colnames(df) %in% c(outcome_name, treated_name,
                          'matched', 'weight', 'MG', 'CATE'))

  MGs <- x$MGs

  if (x$info$estimate_CATEs) {
    CATEs <- df$CATE
  } else {
    # Unit-level CATE: own outcome against the mean outcome of the
    # opposite-treatment members of its matched group; NA if unmatched.
    CATEs <- vapply(seq_len(nrow(df)), function(i) {
      MG <- MGs[[i]]
      if (is.null(MG)) {
        return(NA_real_)
      }
      if (Tr[i] == 1) {
        Y[i] - mean(Y[MG[Tr[MG] == 0]])
      } else {
        mean(Y[MG[Tr[MG] == 1]] - Y[i])
      }
    }, numeric(1))
  }

  MG_size <- vapply(MGs, length, FUN.VALUE = numeric(1))[matched]
  CATEs <- CATEs[matched]
  ATE <- mean(CATEs, na.rm = TRUE)

  # Evaluate each matched unit at its own row in x$data / x$MGs (not at
  # positions 1..n_matched, which differ whenever unmatched rows come first).
  n_covs_matched <-
    vapply(which(matched),
           function(z) n_covs_matched_on(x, z, cov_inds),
           numeric(1))

  # Dashed line at the estimated ATE, dotted null-effect line when `vals`
  # straddle zero, and a matching legend.
  add_reference_lines <- function(vals, vertical) {
    include_null <-
      min(vals, na.rm = TRUE) < 0 && 0 < max(vals, na.rm = TRUE)
    if (vertical) {
      abline(v = ATE, lty = 2, col = 'blue')
      if (include_null) {
        abline(v = 0, lty = 3, col = 'red')
      }
    } else {
      abline(h = ATE, lty = 2, col = 'blue')
      if (include_null) {
        abline(h = 0, lty = 3, col = 'red')
      }
    }
    if (include_null) {
      legend('topright',
             legend = c('Estimated ATE', 'Null Effect'),
             lty = c(2, 3),
             col = c('blue', 'red'))
    } else {
      legend('topright', legend = c('Estimated ATE'), lty = 2, col = 'blue')
    }
  }

  # Plot 1: number of covariates matched on.
  plot_covs_matched_on <- function() {
    barplot(table(n_covs_matched[n_covs_matched > 0]),
            xlab = 'Number of Covariates Matched On',
            ylab = 'Number of Units')
  }

  # Plot 2: matched group size vs. CATE.
  plot_size_vs_cate <- function() {
    plot(MG_size, CATEs,
         xlab = 'Size of Matched Group',
         ylab = 'Estimated Conditional Average Treatment Effect')
    add_reference_lines(CATEs, vertical = FALSE)
  }

  # Plot 3: density of the estimated CATEs.
  plot_cate_density <- function() {
    dens <- density(CATEs, na.rm = TRUE)
    plot(dens,
         xlab = c('Estimated Conditional Average Treatment Effect'),
         ylab = '', main = '',
         zero.line = FALSE)
    add_reference_lines(dens$x, vertical = TRUE)
  }

  # Plot 4: heatmap of covariates dropped (black) at each iteration.
  plot_dropped_covariates <- function() {
    covs_dropped <- matrix(1, nrow = sum(cov_inds), ncol = length(x$cov_sets),
                           dimnames = list(colnames(df)[cov_inds],
                                           seq_along(x$cov_sets)))
    for (i in seq_along(x$cov_sets)) {
      covs_dropped[x$cov_sets[[i]], i] <- 0
    }

    # Widen the left margin for covariate names; restore it even on error.
    # Thanks to: https://stackoverflow.com/questions/5506046
    old_par <- par(mar = c(5, 7, 4, 2) + 0.1)
    on.exit(par(old_par), add = TRUE)

    image(z = t(covs_dropped), col = c('black', 'white'),
          xaxt = 'n', yaxt = 'n',
          xlab = 'Iteration', cex.lab = 1.2)

    axis(side = 1, at = seq(0, 1, length.out = ncol(covs_dropped)),
         labels = seq_len(ncol(covs_dropped)))
    axis(side = 2, at = seq(0, 1, length.out = nrow(covs_dropped)),
         labels = rownames(covs_dropped), las = 1)

    title(ylab = 'Variables Dropped', line = 6, cex.lab = 1.2)
  }

  plotters <- list(plot_covs_matched_on, plot_size_vs_cate,
                   plot_cate_density, plot_dropped_covariates)

  for (k in seq_along(which_plots)) {
    # Pause between plots in interactive sessions, but not before the first.
    if (k > 1 && interactive()) {
      readline(prompt = "Press <enter> to view next plot")
    }
    plotters[[which_plots[k]]]()
  }

  invisible(x)
}

Try the FLAME package in your browser

Any scripts or data that you put into this service are public.

FLAME documentation built on Dec. 11, 2021, 9:26 a.m.