R/balqual.R

Defines functions show_quality balqual

Documented in balqual

#' @title Evaluate Matching Quality
#'
#' @description The `balqual()` function evaluates the balance quality of a
#'   dataset after matching, comparing it to the original unbalanced dataset. It
#'   computes various summary statistics and provides an easy interpretation
#'   using user-specified cutoff values.
#'
#' @param matched_data An object of class `matched`, generated by the
#'   [match_gps()] function. This object is essential for the `balqual()`
#'   function as it contains the final data.frame and attributes required to
#'   compute the quality coefficients.
#' @param formula A valid R formula used to compute generalized propensity
#'   scores during the first step of the vector matching algorithm in
#'   [estimate_gps()]. This formula must match the one used in `estimate_gps()`.
#' @param type A character vector specifying the quality metrics to calculate.
#'   It can contain up to three values, combined with `c()`. Possible values
#'   are:
#' * `smd` - Calculates standardized mean differences (SMD) between groups,
#'   defined as the difference in means divided by the standard deviation of the
#'   treatment group (Rubin, 2001).
#' * `r` - Computes Pearson's r coefficient using the Z statistic from the
#'   Mann-Whitney U test.
#' * `var_ratio` - Measures the dispersion differences between groups,
#'   calculated as the ratio of the larger variance to the smaller one.
#' @param statistic A character vector specifying the type of statistics used to
#'   summarize the quality metrics. Since quality metrics are calculated for all
#'   pairwise comparisons between treatment levels, they need to be aggregated
#'   for the entire dataset. Possible values are:
#' * `max` - Returns the maximum values of the statistics defined in the `type`
#'   argument (as suggested by Lopez and Gutman, 2017).
#' * `mean` - Returns the corresponding averages.
#'
#'   To compute both, provide both names using the `c()` function.
#' @param cutoffs A numeric vector with the same length as the number of
#'   coefficients specified in the `type` argument. Defines the cutoffs for each
#'   corresponding metric, below which the dataset is considered balanced. If
#'   `NULL`, the default cutoffs are used: 0.1 for `smd` and `r`, and 2 for
#'   `var_ratio`.
#' @param round An integer specifying the number of decimal places to round the
#'   output to.
#'
#' @return Invisibly returns a list of summary statistics of class `quality`
#'   containing:
#'  * `quality_mean` - A data frame with the mean values of the statistics
#'   specified in the `type` argument for all balancing variables used in
#'   `formula`.
#'  * `quality_max` - A data frame with the maximal values of the statistics
#'   specified in the `type` argument for all balancing variables used in
#'   `formula`.
#'  * `perc_matched` - A single numeric value indicating the percentage of
#'   observations in the original dataset that were matched.
#'  * `statistic` - A character vector defining which statistics (`mean`,
#'   `max`, or both) are displayed in the console.
#'  * `summary_head` - A summary of the matching process. If `max` is included
#'   in the `statistic`, it contains the maximal observed values for each
#'   variable; otherwise, it includes the mean values.
#'  * `n_before` - The number of observations in the dataset before matching.
#'  * `n_after` - The number of observations in the dataset after matching.
#'  * `count_table` - A contingency table showing the distribution of the
#'   treatment variable before and after matching.
#'
#'   The `balqual()` function also prints a well-formatted table with the
#'   defined summary statistics for each variable in the `formula` to the
#'   console.
#'
#' @examples
#' # We try to balance the treatment variable in the cancer dataset based on age
#' # and sex covariates
#' data(cancer)
#'
#' # Firstly, we define the formula
#' formula_cancer <- formula(status ~ age * sex)
#'
#' # Then we can estimate the generalized propensity scores
#' gps_cancer <- estimate_gps(formula_cancer,
#'   cancer,
#'   method = "multinom",
#'   reference = "control",
#'   verbose_output = TRUE
#' )
#'
#' # ... and drop observations based on the common support region...
#' csr_cancer <- csregion(gps_cancer)
#'
#' # ... to match the samples using `match_gps()`
#' matched_cancer <- match_gps(csr_cancer,
#'   reference = "control",
#'   caliper = 1,
#'   kmeans_cluster = 5,
#'   kmeans_args = list(iter.max = 100),
#'   verbose_output = TRUE
#' )
#'
#' # At the end we can assess the quality of matching using `balqual()`
#' balqual(
#'   matched_data = matched_cancer,
#'   formula = formula_cancer,
#'   type = "smd",
#'   statistic = "max",
#'   round = 3,
#'   cutoffs = 0.2
#' )
#'
#' @seealso [match_gps()] for matching the generalized propensity scores;
#' [estimate_gps()] for the documentation of the `formula` argument.
#'
#' @references Rubin, D.B. Using Propensity Scores to Help Design Observational
#'   Studies: Application to the Tobacco Litigation. Health Services & Outcomes
#'   Research Methodology 2, 169–188 (2001).
#'   https://doi.org/10.1023/A:1020363010465
#'
#'   Lopez, M.J., Gutman, R. Estimation of Causal Effects with Multiple
#'   Treatments: A Review and New Ideas. Statistical Science 32(3), 432-454
#'   (2017).
#' @export
#'
balqual <- function(matched_data = NULL,
                    formula = NULL,
                    type = c("smd", "r", "var_ratio"),
                    statistic = c("mean", "max"),
                    cutoffs = NULL,
                    round = 3) {
  ############ ARGUMENT CHECKING AND PROCESSING ################################
  # check data
  .chk_cond(
    is.null(matched_data),
    "The argument `matched_data` is missing with no default!"
  )

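  # `inherits()` with several classes is TRUE if ANY of them matches, so the
  # data.frame and `matched` requirements are checked separately
  .chk_cond(
    !(is.data.frame(matched_data) && inherits(matched_data, "matched")),
    "The argument `matched_data` has to be a data.frame of class
            `matched`!"
  )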

  ## get rid of other classes
  old_data <- data.frame(attr(matched_data, "original_data"))
  new_data <- data.frame(matched_data)

  # check and process formula
  data_before <- .process_formula(formula, old_data)
  data_after <- .process_formula(formula, new_data)

  # define all possible pairwise comparisons of treatment var
  pairwise_comb <- t(utils::combn(unique(data_before[["treat"]]), 2))
  pairwise_comb <- as.data.frame(pairwise_comb)
  pairwise_comb <- pairwise_comb[stats::complete.cases(pairwise_comb), ]
  colnames(pairwise_comb) <- c("group1", "group2")

  # check and process type
  .chk_cond(
    !.check_vecl(type, check_numeric = FALSE),
    "The `type` argument has to be an atomic vector!"
  )

  type <- .match_discrete_args(type, c("smd", "r", "var_ratio"), "type")

  # check statistic
  .chk_cond(
    !.check_vecl(statistic, check_numeric = FALSE),
    "The `statistic` argument has to be an atomic vector!"
  )

  statistic <- .match_discrete_args(statistic, c("mean", "max"), "statistic")

  # check cutoffs
  if (is.null(cutoffs)) {
    cutoffs <- unlist(lapply(type, function(x) {
      switch(x,
        smd = 0.1,
        r = 0.1,
        var_ratio = 2
      )
    }))
  }

  .chk_cond(
    !.check_vecl(cutoffs, length(type), check_numeric = TRUE),
    "The argument `cutoffs` has to be a numeric vector, with length
            equal to the length of `type` argument."
  )

  # check round
  .check_integer(round, "round")

  ############ PERFORMING THE CALCULATIONS #####################################
  # define output list
  quality_list <- .make_list(nrow(pairwise_comb))

  # loop over all pairwise combinations to calculate the quality metrics
  for (i in seq_len(nrow(pairwise_comb))) {
    # helper function to recode data to 0 (control) and 1 (treatment);
    # observations outside the current pair stay NA
    .binarize_treat <- function(treatment, comb1, comb2) {
      treat_recoded <- rep(NA_real_, length(treatment))

      # assign 0 to controls
      treat_recoded[treatment == comb1] <- 0

      # and 1 to treatments
      treat_recoded[treatment == comb2] <- 1

      treat_recoded
    }

    # recode treatment to binary
    treat_before <- .binarize_treat(
      data_before[["treat"]],
      pairwise_comb[i, 1],
      pairwise_comb[i, 2]
    )

    treat_after <- .binarize_treat(
      data_after[["treat"]],
      pairwise_comb[i, 1],
      pairwise_comb[i, 2]
    )

    # define functions to calculate smd, r and the variance ratio
    # smd
    standardized_bias <- function(treatment, covariate) {
      remove_nas <- which(!is.na(treatment))
      covariate <- covariate[remove_nas]
      treatment <- treatment[remove_nas]

      means <- tapply(
        covariate,
        treatment,
        mean
      )
      sd_treat <- stats::sd(covariate[treatment == 1]) # as suggested by Rubin
      (means[1] - means[2]) / sd_treat
    }

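    # r
    wilcox_r <- function(treatment, covariate) {
      # observations with NA treatment are dropped by split()
      data_split <- split(covariate, treatment)

      p_value <- stats::wilcox.test(
        x = data_split[["0"]],
        y = data_split[["1"]],
        paired = FALSE
      )$p.value

      # recover the Z statistic from the two-sided p-value
      z_stat <- stats::qnorm(p_value / 2)

      # effect size r = Z / sqrt(N), with N the total number of observations
      n_total <- length(data_split[["0"]]) + length(data_split[["1"]])
      z_stat / sqrt(n_total)
    }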

    # variance ratio
    var_ratio <- function(treatment, covariate) {
      remove_nas <- which(!is.na(treatment))
      covariate <- covariate[remove_nas]
      treatment <- treatment[remove_nas]

      # group variances (not standard deviations), as documented for
      # `var_ratio`
      var_vals <- tapply(
        covariate,
        treatment,
        stats::var
      )

      max(var_vals) / min(var_vals)
    }

    # calculating the statistics
    smd_before <- apply(data_before[["model_covs"]], 2, function(x) {
      standardized_bias(treat_before, x)
    })

    smd_after <- apply(data_after[["model_covs"]], 2, function(x) {
      standardized_bias(treat_after, x)
    })

    r_effsize_before <- apply(data_before[["model_covs"]], 2, function(x) {
      wilcox_r(treat_before, x)
    })

    r_effsize_after <- apply(data_after[["model_covs"]], 2, function(x) {
      wilcox_r(treat_after, x)
    })

    variance_before <- apply(data_before[["model_covs"]], 2, function(x) {
      var_ratio(treat_before, x)
    })

    variance_after <- apply(data_after[["model_covs"]], 2, function(x) {
      var_ratio(treat_after, x)
    })

    variables <- rbind(
      smd_before,
      smd_after,
      r_effsize_before,
      r_effsize_after,
      variance_before,
      variance_after
    )

    variables <- abs(variables)

    # convert to characters for duplicates check
    var_uniques <- data.frame(lapply(t(variables), function(x) {
      as.character(round(x, 8))
    }))

    var_uniques <- duplicated(var_uniques)

    # subset the vars by unique values
    variables <- variables[, !var_uniques, drop = FALSE]
    colnames(variables) <- colnames(data_before[["model_covs"]])[!var_uniques]

    # assign to list
    quality_list[[i]] <- as.data.frame(variables)
  }

  # output data.frames
  quality_dataframe <- data.frame(
    coef_name = rep(c("smd", "r", "var_ratio"),
      each = 2
    ),
    time = rep(c("before", "after"), 3)
  )

  # means
  quality_mean <- create_balqual_output(quality_list,
    quality_dataframe,
    operation = "+",
    round = round,
    which_coefs = type,
    cutoffs = cutoffs
  )

  # maxes
  quality_max <- create_balqual_output(quality_list,
    quality_dataframe,
    operation = "max",
    round = round,
    which_coefs = type,
    cutoffs = cutoffs
  )

  # % Matched
  perc_matched <- length(data_after[["treat"]]) /
    length(data_before[["treat"]]) * 100

  perc_matched <- round(perc_matched, 2)

  # calculating total maxima or means
  type_recoded <- lapply(type, function(x) {
    switch(x,
      "smd" = "SMD",
      "r" = "r",
      "var_ratio" = "Var"
    )
  })

  # switch to use different dataframes based on statistics
  summary_head <- lapply(type_recoded, function(x) {
    if ("max" %in% statistic) {
      max(quality_max$After[quality_max$Coef == x])
    } else {
      max(quality_mean$After[quality_mean$Coef == x])
    }
  })

  summary_head <- unlist(summary_head)

  names(summary_head) <- unlist(type_recoded)

  # count table for treatment variable
  times <- c(
    rep("Before", length(data_before[["treat"]])),
    rep("After", length(data_after[["treat"]]))
  )

  count_data <- data.frame(
    Matching = times,
    Treatment = c(
      data_before[["treat"]],
      data_after[["treat"]]
    )
  )

  count_table <- with(count_data, table(Treatment, Matching))

  count_table <- cbind(Treatment = rownames(count_table), count_table)
  count_table <- count_table[, c("Treatment", "Before", "After")]
  rownames(count_table) <- NULL

  # defining final object for printing
  quality_list_print <- list(
    quality_mean = quality_mean,
    quality_max = quality_max,
    perc_matched = perc_matched,
    statistic = statistic,
    summary_head = summary_head,
    n_before = length(data_before[["treat"]]),
    n_after = length(data_after[["treat"]]),
    count_table = count_table
  )

  ## Setting a new class for the results
  quality_list_print <- structure(quality_list_print,
    class = "quality"
  )

  ## print custom output
  show_quality(quality_list_print)

  ## return output list
  return(invisible(quality_list_print))
}
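
# Illustrative sketch, not part of the package API: the `smd` and `var_ratio`
# metrics documented for `balqual()`, computed for a single covariate and a
# binary (0/1) treatment vector, followed by a mean/max aggregation across
# pairwise comparisons. The names `sketch_balance_metrics()` and
# `sketch_aggregate()` are hypothetical, and the aggregation is only an
# assumption about what `create_balqual_output()` does, since that helper is
# defined elsewhere.
sketch_balance_metrics <- function(treatment, covariate,
                                   cutoffs = c(smd = 0.1, var_ratio = 2)) {
  x0 <- covariate[treatment == 0]
  x1 <- covariate[treatment == 1]

  # absolute difference in means, scaled by the SD of the treatment group
  smd <- abs(mean(x0) - mean(x1)) / stats::sd(x1)

  # larger variance divided by the smaller one, so the ratio is always >= 1
  vars <- c(stats::var(x0), stats::var(x1))
  var_ratio <- max(vars) / min(vars)

  values <- c(smd = smd, var_ratio = var_ratio)

  # a metric counts as balanced when it falls below its cutoff
  list(values = values, balanced = values < cutoffs[names(values)])
}

sketch_aggregate <- function(metric_list, statistic = c("mean", "max")) {
  statistic <- match.arg(statistic)

  # one row per pairwise comparison, one column per covariate
  metric_matrix <- do.call(rbind, metric_list)

  if (statistic == "mean") {
    colMeans(metric_matrix)
  } else {
    apply(metric_matrix, 2, max)
  }
}

# Usage (toy data):
#   sketch_balance_metrics(c(0, 0, 0, 0, 1, 1, 1, 1), c(1, 2, 3, 4, 2, 3, 4, 5))
#   sketch_aggregate(list(c(age = 0.1, sex = 0.3), c(age = 0.3, sex = 0.1)), "max")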

# Function to show the contents of an object of class 'quality'
show_quality <- function(object) {
  # Helper function to print a table
  print_table <- function(df, colnames_df, ncol) {
    # Define separator based on the number of columns
    separator <- paste(rep("-", ifelse(ncol == 3, 50, 80)), collapse = "")

    # Define column formats based on the number of columns
    col_format <- if (ncol == 3) {
      c("%-25s", "%-10s", "%-10s")
    } else {
      c("%-25s", "%-5s", "%-12s", "%-12s", "%-12s")
    }

    # Print table header
    cat(separator, "\n")
    cat(paste(
      sprintf(col_format[1], colnames_df[1]), " | ",
      sprintf(col_format[2], colnames_df[2]), " | ",
      if (ncol == 3) sprintf(col_format[3], colnames_df[3]),
      if (ncol == 5) paste0(sprintf(col_format[3], colnames_df[3]), " | "),
      if (ncol == 5) paste0(sprintf(col_format[4], colnames_df[4]), " | "),
      if (ncol == 5) sprintf(col_format[5], colnames_df[5]),
      sep = ""
    ), "\n")
    cat(separator, "\n")

    # Print each row of the data frame
    apply(df, 1, function(row) {
      cat(paste(
        sprintf(col_format[1], row[1]), " | ",
        sprintf(col_format[2], row[2]), " | ",
        if (ncol == 3) sprintf(col_format[3], row[3]),
        if (ncol == 5) paste0(sprintf(col_format[3], row[3]), " | "),
        if (ncol == 5) paste0(sprintf(col_format[4], row[4]), " | "),
        if (ncol == 5) sprintf(col_format[5], row[5]),
        sep = ""
      ), "\n")
    })
    cat(separator, "\n")
  }

  # Function to print quality statistics tables (mean and max)
  print_quality_table <- function(table, label) {
    cat(label, ":\n")
    colnames_main <- c("Variable", "Coef", "Before", "After", "Quality")
    print_table(table,
      colnames_df = colnames_main, ncol = 5
    )
    cat("\n")
  }

  # Header for the output
  cat("\nMatching Quality Evaluation\n")
  cat(paste(rep("=", 80), collapse = ""), "\n\n")

  # Print the count table for the treatment variable (3 columns)
  cat("Count table for the treatment variable:\n")
  print_table(object$count_table,
    colnames_df = colnames(object$count_table),
    ncol = 3
  )

  # Matching summary statistics
  cat("\n\nMatching summary statistics:\n")
  cat(paste(rep("-", 40), collapse = ""), "\n")
  cat("Total n before matching:\t", format(object$n_before, nsmall = 0), "\n")
  cat("Total n after matching:\t\t", format(object$n_after, nsmall = 0), "\n")
  cat(
    "% of matched observations:\t",
    format(object$perc_matched, nsmall = 2), "%\n"
  )

  # Print summary_head (maximal values for each variable)
  for (i in seq_along(object$summary_head)) {
    sum_text <- "maximal"
    tab_after <- ifelse(
      names(
        object$summary_head[i]
      ) == "r" && "max" %nin% object$statistic,
      "\t\t",
      "\t"
    )
    cat(
      "Total ", sum_text, " ", names(object$summary_head[i]), "value:",
      tab_after, object$summary_head[i], "\n"
    )
  }

  cat("\n\n")

  # Print mean and/or max quality tables based on the statistic
  if (length(object$statistic) == 1) {
    if (object$statistic == "mean") {
      print_quality_table(object$quality_mean, "Mean values")
    } else {
      print_quality_table(object$quality_max, "Maximal values")
    }
  } else {
    print_quality_table(object$quality_mean, "Mean values")
    cat("\n")
    print_quality_table(object$quality_max, "Maximal values")
  }
}
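
# Illustrative sketch, not part of the package API: the fixed-width row layout
# that `print_table()` builds with `sprintf()` can be written for any number of
# columns. `sketch_format_row()` is a hypothetical helper; `widths` gives the
# minimum width of each left-aligned cell.
sketch_format_row <- function(cells, widths) {
  # e.g. "%-25s" left-aligns a string and pads it with spaces to 25 characters
  formats <- paste0("%-", widths, "s")
  padded <- mapply(sprintf, formats, as.character(cells), USE.NAMES = FALSE)

  # join the padded cells with the same " | " divider used in `print_table()`
  paste(padded, collapse = " | ")
}

# Usage:
#   cat(sketch_format_row(c("Treatment", "Before", "After"), c(25, 10, 10)), "\n")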


vecmatch documentation built on April 3, 2025, 8:46 p.m.