Nothing
# ==============================================================================
# Balance Diagnostics for Matching Results
# ==============================================================================
# Functions to assess the quality of matches by comparing distributions
# of matching variables between left and right units in matched pairs
#' Calculate Standardized Difference
#'
#' Computes the standardized mean difference between two groups.
#' This is a key metric for assessing balance in matched samples.
#'
#' @param x1 Numeric vector for group 1
#' @param x2 Numeric vector for group 2
#' @param pooled Logical, if TRUE use pooled standard deviation (default),
#' if FALSE use group 1 standard deviation
#'
#' @return Numeric value representing the standardized difference
#'
#' @details
#' Standardized difference = (mean1 - mean2) / pooled_sd
#' where pooled_sd = sqrt((sd1^2 + sd2^2) / 2)
#'
#' Common thresholds: less than 0.1 is excellent balance, 0.1-0.25 is good
#' balance, 0.25-0.5 is acceptable balance, and greater than 0.5 is poor balance.
#'
#' @keywords internal
standardized_difference <- function(x1, x2, pooled = TRUE) {
# Remove NA values
x1 <- x1[!is.na(x1)]
x2 <- x2[!is.na(x2)]
# Handle edge cases
if (length(x1) == 0 || length(x2) == 0) {
return(NA_real_)
}
mean1 <- mean(x1)
mean2 <- mean(x2)
mean_diff <- mean1 - mean2
if (pooled) {
# Pooled standard deviation
sd1 <- stats::sd(x1)
sd2 <- stats::sd(x2)
pooled_sd <- sqrt((sd1^2 + sd2^2) / 2)
} else {
# Use group 1 SD
pooled_sd <- stats::sd(x1)
}
# Avoid division by zero
if (is.na(pooled_sd) || pooled_sd == 0) {
return(0)
}
return(mean_diff / pooled_sd)
}
#' Calculate Variable-Level Balance Statistics
#'
#' @param left_vals Numeric vector of values from left group
#' @param right_vals Numeric vector of values from right group
#' @param var_name Character, name of the variable
#'
#' @return List with balance statistics for this variable
#'
#' @keywords internal
calculate_var_balance <- function(left_vals, right_vals, var_name) {
# Remove NAs
left_clean <- left_vals[!is.na(left_vals)]
right_clean <- right_vals[!is.na(right_vals)]
# Basic statistics
mean_left <- mean(left_clean)
mean_right <- mean(right_clean)
mean_diff <- mean_left - mean_right
sd_left <- stats::sd(left_clean)
sd_right <- stats::sd(right_clean)
# Standardized difference
std_diff <- standardized_difference(left_clean, right_clean, pooled = TRUE)
# Variance ratio
var_ratio <- if (!is.na(sd_right) && sd_right > 0) sd_left / sd_right else NA_real_
# Kolmogorov-Smirnov test for distribution similarity
ks_result <- tryCatch({
stats::ks.test(left_clean, right_clean)
}, error = function(e) NULL)
ks_statistic <- if (!is.null(ks_result)) ks_result$statistic else NA_real_
ks_pvalue <- if (!is.null(ks_result)) ks_result$p.value else NA_real_
# Return statistics
list(
variable = var_name,
mean_left = mean_left,
mean_right = mean_right,
mean_diff = mean_diff,
sd_left = sd_left,
sd_right = sd_right,
std_diff = std_diff,
var_ratio = var_ratio,
ks_statistic = ks_statistic,
ks_pvalue = ks_pvalue,
n_left = length(left_clean),
n_right = length(right_clean)
)
}
#' Balance Diagnostics for Matched Pairs
#'
#' Computes comprehensive balance statistics comparing the distribution of
#' matching variables between left and right units in the matched sample.
#'
#' @param result A matching result object from \code{match_couples()} or
#' \code{greedy_couples()}
#' @param left Data frame of left units
#' @param right Data frame of right units
#' @param vars Character vector of variable names to check balance for.
#' Defaults to the variables used in matching (if available in result).
#' @param left_id Character, name of ID column in left data (default: "id")
#' @param right_id Character, name of ID column in right data (default: "id")
#'
#' @return An S3 object of class \code{balance_diagnostics} containing:
#' \describe{
#' \item{var_stats}{Tibble with per-variable balance statistics}
#' \item{overall}{List with overall balance metrics}
#' \item{pairs}{Tibble of matched pairs with variables}
#' \item{n_matched}{Number of matched pairs}
#' \item{n_unmatched_left}{Number of unmatched left units}
#' \item{n_unmatched_right}{Number of unmatched right units}
#' \item{method}{Matching method used}
#' \item{has_blocks}{Whether blocking was used}
#' \item{block_stats}{Per-block statistics (if blocking used)}
#' }
#'
#' @details
#' This function computes several balance metrics:
#'
#' Standardized Difference: The difference in means divided by the pooled
#' standard deviation. Values less than 0.1 indicate excellent balance,
#' 0.1-0.25 good balance.
#'
#' Variance Ratio: The ratio of standard deviations (left/right).
#' Values close to 1 are ideal.
#'
#' KS Statistic: Kolmogorov-Smirnov test statistic comparing distributions.
#' Lower values indicate more similar distributions.
#'
#' Overall Metrics include mean absolute standardized difference across
#' all variables, proportion of variables with large imbalance
#' (|std diff| > 0.25), and maximum standardized difference.
#'
#' @examples
#' # Create sample data
#' set.seed(123)
#' left <- data.frame(
#' id = 1:10,
#' age = rnorm(10, 45, 10),
#' income = rnorm(10, 50000, 15000)
#' )
#' right <- data.frame(
#' id = 11:30,
#' age = rnorm(20, 47, 10),
#' income = rnorm(20, 52000, 15000)
#' )
#'
#' # Match
#' result <- match_couples(left, right, vars = c("age", "income"))
#'
#' # Get balance diagnostics
#' balance <- balance_diagnostics(result, left, right, vars = c("age", "income"))
#' print(balance)
#'
#' # Get balance table
#' balance_table(balance)
#'
#' @export
balance_diagnostics <- function(result,
left,
right,
vars = NULL,
left_id = "id",
right_id = "id") {
# Validate inputs
if (!inherits(result, "matching_result")) {
stop("result must be a matching_result object from match_couples() or greedy_couples()")
}
if (!left_id %in% names(left)) {
stop("left_id '", left_id, "' not found in left data")
}
if (!right_id %in% names(right)) {
stop("right_id '", right_id, "' not found in right data")
}
# If vars not specified, try to get from result
if (is.null(vars)) {
if (!is.null(result$info$vars)) {
vars <- result$info$vars
} else {
stop("vars must be specified (could not infer from result)")
}
}
# Validate vars exist
missing_left <- setdiff(vars, names(left))
missing_right <- setdiff(vars, names(right))
if (length(missing_left) > 0) {
stop("Variables not found in left: ", paste(missing_left, collapse = ", "))
}
if (length(missing_right) > 0) {
stop("Variables not found in right: ", paste(missing_right, collapse = ", "))
}
# Extract matched pairs
pairs <- result$pairs
# Filter to only matched pairs (right_id > 0)
matched_pairs <- pairs[pairs[[2]] > 0, ]
n_matched <- nrow(matched_pairs)
# Count unmatched
n_unmatched_left <- sum(pairs[[2]] == 0)
n_total_left <- nrow(left)
n_total_right <- nrow(right)
n_unmatched_right <- n_total_right - n_matched
# Join variables for matched pairs
left_matched <- merge(
matched_pairs[, 1, drop = FALSE],
left[, c(left_id, vars), drop = FALSE],
by.x = names(matched_pairs)[1],
by.y = left_id,
all.x = TRUE
)
right_matched <- merge(
matched_pairs[, 2, drop = FALSE],
right[, c(right_id, vars), drop = FALSE],
by.x = names(matched_pairs)[2],
by.y = right_id,
all.x = TRUE
)
# Calculate balance for each variable
var_stats_list <- lapply(vars, function(v) {
left_vals <- left_matched[[v]]
right_vals <- right_matched[[v]]
calculate_var_balance(left_vals, right_vals, v)
})
# Convert to tibble
var_stats <- tibble::as_tibble(do.call(rbind, lapply(var_stats_list, as.data.frame)))
# Overall balance metrics
abs_std_diffs <- abs(var_stats$std_diff)
mean_abs_std_diff <- mean(abs_std_diffs, na.rm = TRUE)
max_abs_std_diff <- if (all(is.na(abs_std_diffs))) NA_real_ else max(abs_std_diffs, na.rm = TRUE)
pct_large_imbalance <- mean(abs_std_diffs > 0.25, na.rm = TRUE) * 100
overall <- list(
mean_abs_std_diff = mean_abs_std_diff,
max_abs_std_diff = max_abs_std_diff,
pct_large_imbalance = pct_large_imbalance,
n_vars = length(vars)
)
# Check for blocking
has_blocks <- "block_id" %in% names(pairs)
block_stats <- NULL
if (has_blocks) {
# Calculate per-block statistics
blocks <- unique(matched_pairs$block_id)
blocks <- blocks[!is.na(blocks)]
block_stats_list <- lapply(blocks, function(b) {
block_pairs <- matched_pairs[matched_pairs$block_id == b, ]
# Get left and right values for this block
left_block <- merge(
block_pairs[, 1, drop = FALSE],
left[, c(left_id, vars), drop = FALSE],
by.x = names(block_pairs)[1],
by.y = left_id
)
right_block <- merge(
block_pairs[, 2, drop = FALSE],
right[, c(right_id, vars), drop = FALSE],
by.x = names(block_pairs)[2],
by.y = right_id
)
# Calculate mean std diff for this block
block_var_stats <- lapply(vars, function(v) {
calculate_var_balance(left_block[[v]], right_block[[v]], v)
})
mean_std_diff <- mean(abs(sapply(block_var_stats, function(x) x$std_diff)), na.rm = TRUE)
# Determine quality (handle NA/NaN from empty blocks)
quality <- if (is.na(mean_std_diff) || is.nan(mean_std_diff)) {
"Unknown"
} else if (mean_std_diff < 0.1) {
"Excellent"
} else if (mean_std_diff < 0.25) {
"Good"
} else if (mean_std_diff < 0.5) {
"Fair"
} else {
"Poor"
}
# Find worst variable
std_diffs <- sapply(block_var_stats, function(x) abs(x$std_diff))
worst_var_idx <- which.max(std_diffs)
worst_var <- if (length(worst_var_idx) > 0) vars[worst_var_idx] else NA_character_
data.frame(
block_id = b,
n_pairs = nrow(block_pairs),
mean_std_diff = mean_std_diff,
quality = quality,
worst_var = worst_var,
stringsAsFactors = FALSE
)
})
block_stats <- tibble::as_tibble(do.call(rbind, block_stats_list))
}
# Build result object
result_obj <- list(
var_stats = var_stats,
overall = overall,
pairs = matched_pairs,
n_matched = n_matched,
n_unmatched_left = n_unmatched_left,
n_unmatched_right = n_unmatched_right,
n_total_left = n_total_left,
n_total_right = n_total_right,
method = result$info$method %||% "unknown",
has_blocks = has_blocks,
block_stats = block_stats
)
class(result_obj) <- "balance_diagnostics"
return(result_obj)
}
#' Create Balance Table
#'
#' Formats balance diagnostics into a clean table for display or export.
#'
#' @param balance A balance_diagnostics object from \code{balance_diagnostics()}
#' @param digits Number of decimal places for rounding (default: 3)
#'
#' @return A tibble with formatted balance statistics
#'
#' @export
balance_table <- function(balance, digits = 3) {
if (!inherits(balance, "balance_diagnostics")) {
stop("balance must be a balance_diagnostics object")
}
# Format variable statistics
tbl <- balance$var_stats[, c("variable", "mean_left", "mean_right", "mean_diff",
"std_diff", "var_ratio", "ks_statistic")]
# Round numeric columns
tbl$mean_left <- round(tbl$mean_left, digits)
tbl$mean_right <- round(tbl$mean_right, digits)
tbl$mean_diff <- round(tbl$mean_diff, digits)
tbl$std_diff <- round(tbl$std_diff, digits)
tbl$var_ratio <- round(tbl$var_ratio, digits)
tbl$ks_statistic <- round(tbl$ks_statistic, digits)
# Rename columns
names(tbl) <- c("Variable", "Mean Left", "Mean Right", "Mean Diff",
"Std Diff", "Var Ratio", "KS Stat")
return(tibble::as_tibble(tbl))
}
#' Print Method for Balance Diagnostics
#'
#' @param x A balance_diagnostics object
#' @param ... Additional arguments (ignored)
#'
#' @return Invisibly returns the input object `x`.
#' @export
print.balance_diagnostics <- function(x, ...) {
cat("\n")
cat("Balance Diagnostics for Matched Pairs\n")
cat("======================================\n\n")
# Matching summary
cat("Matching Summary:\n")
cat(sprintf(" Method: %s\n", x$method))
cat(sprintf(" Matched pairs: %d\n", x$n_matched))
cat(sprintf(" Unmatched left: %d (of %d)\n", x$n_unmatched_left, x$n_total_left))
cat(sprintf(" Unmatched right: %d (of %d)\n", x$n_unmatched_right, x$n_total_right))
cat("\n")
# Variable-level statistics
cat("Variable-level Balance:\n")
print(balance_table(x, digits = 3), n = Inf)
cat("\n")
# Overall assessment
cat("Overall Balance:\n")
cat(sprintf(" Mean |Std Diff|: %.3f ", x$overall$mean_abs_std_diff))
if (x$overall$mean_abs_std_diff < 0.1) {
cat("(Excellent)\n")
} else if (x$overall$mean_abs_std_diff < 0.25) {
cat("(Good)\n")
} else if (x$overall$mean_abs_std_diff < 0.5) {
cat("(Acceptable)\n")
} else {
cat("(Poor)\n")
}
cat(sprintf(" Max |Std Diff|: %.3f\n", x$overall$max_abs_std_diff))
cat(sprintf(" Vars with |Std Diff| > 0.25: %.1f%%\n", x$overall$pct_large_imbalance))
cat("\n")
# Block summaries if present
if (x$has_blocks && !is.null(x$block_stats)) {
cat("Block-level Balance:\n")
print(x$block_stats, n = Inf)
cat("\n")
}
cat("Balance Interpretation:\n")
cat(" |Std Diff| < 0.10: Excellent balance\n")
cat(" |Std Diff| 0.10-0.25: Good balance\n")
cat(" |Std Diff| 0.25-0.50: Acceptable balance\n")
cat(" |Std Diff| > 0.50: Poor balance\n")
cat("\n")
invisible(x)
}
#' Summary method for balance diagnostics
#'
#' @param object A balance_diagnostics object
#' @param ... Additional arguments (ignored)
#'
#' @return A list containing summary statistics (invisibly)
#' @export
#' @method summary balance_diagnostics
summary.balance_diagnostics <- function(object, ...) {
std_diffs <- abs(object$var_stats$std_diff)
var_ratios <- object$var_stats$var_ratio
# Classify balance quality
quality <- if (object$overall$mean_abs_std_diff < 0.1) {
"Excellent"
} else if (object$overall$mean_abs_std_diff < 0.25) {
"Good"
} else if (object$overall$mean_abs_std_diff < 0.5) {
"Acceptable"
} else {
"Poor"
}
out <- list(
n_matched = object$n_matched,
n_unmatched_left = object$n_unmatched_left,
n_unmatched_right = object$n_unmatched_right,
n_variables = nrow(object$var_stats),
mean_abs_std_diff = object$overall$mean_abs_std_diff,
max_abs_std_diff = object$overall$max_abs_std_diff,
median_abs_std_diff = stats::median(std_diffs, na.rm = TRUE),
pct_excellent = mean(std_diffs < 0.1, na.rm = TRUE) * 100,
pct_good = mean(std_diffs >= 0.1 & std_diffs < 0.25, na.rm = TRUE) * 100,
pct_acceptable = mean(std_diffs >= 0.25 & std_diffs < 0.5, na.rm = TRUE) * 100,
pct_poor = mean(std_diffs >= 0.5, na.rm = TRUE) * 100,
mean_var_ratio = mean(var_ratios, na.rm = TRUE),
quality = quality,
has_blocks = object$has_blocks,
method = object$method
)
class(out) <- "summary.balance_diagnostics"
out
}
#' @export
print.summary.balance_diagnostics <- function(x, ...) {
cat("Balance Diagnostics Summary\n")
cat("===========================\n\n")
cat("Method:", x$method, "\n")
cat("Matched pairs:", x$n_matched, "\n")
cat("Variables assessed:", x$n_variables, "\n\n")
cat("Standardized Differences:\n")
cat(" Mean |Std Diff|:", sprintf("%.3f", x$mean_abs_std_diff), "\n")
cat(" Median |Std Diff|:", sprintf("%.3f", x$median_abs_std_diff), "\n")
cat(" Max |Std Diff|:", sprintf("%.3f", x$max_abs_std_diff), "\n\n")
cat("Balance Quality Distribution:\n")
cat(" Excellent (<0.10):", sprintf("%.1f%%", x$pct_excellent), "\n")
cat(" Good (0.10-0.25):", sprintf("%.1f%%", x$pct_good), "\n")
cat(" Acceptable (0.25-0.50):", sprintf("%.1f%%", x$pct_acceptable), "\n")
cat(" Poor (>0.50):", sprintf("%.1f%%", x$pct_poor), "\n\n")
cat("Overall Quality:", x$quality, "\n")
invisible(x)
}
#' Plot method for balance diagnostics
#'
#' Produces a Love plot (dot plot) of standardized differences.
#'
#' @param x A balance_diagnostics object
#' @param type Type of plot: "love" (default), "histogram", or "variance"
#' @param threshold Threshold line for standardized differences (default: 0.1)
#' @param ... Additional arguments passed to plotting functions
#'
#' @return The balance_diagnostics object (invisibly)
#' @export
#' @method plot balance_diagnostics
plot.balance_diagnostics <- function(x,
type = c("love", "histogram", "variance"),
threshold = 0.1,
...) {
type <- match.arg(type)
switch(type,
love = {
# Love plot (dot plot of standardized differences)
std_diffs <- x$var_stats$std_diff
vars <- x$var_stats$variable
n <- length(std_diffs)
# Order by absolute value
ord <- order(abs(std_diffs), decreasing = TRUE)
std_diffs <- std_diffs[ord]
vars <- vars[ord]
# Set up plot area
old_par <- graphics::par(mar = c(4, 8, 3, 2))
on.exit(graphics::par(old_par))
# Plot
graphics::plot(std_diffs, seq_len(n),
xlim = range(c(-max(abs(std_diffs)) * 1.1, max(abs(std_diffs)) * 1.1, -threshold, threshold)),
yaxt = "n",
xlab = "Standardized Difference",
ylab = "",
main = "Balance: Standardized Differences",
pch = 19,
col = ifelse(abs(std_diffs) > threshold, "firebrick", "steelblue"),
...)
graphics::axis(2, at = seq_len(n), labels = vars, las = 1, cex.axis = 0.8)
graphics::abline(v = 0, lty = 1, col = "gray50")
graphics::abline(v = c(-threshold, threshold), lty = 2, col = "firebrick")
},
histogram = {
# Histogram of standardized differences
std_diffs <- abs(x$var_stats$std_diff)
graphics::hist(std_diffs,
main = "Distribution of |Standardized Differences|",
xlab = "|Standardized Difference|",
col = "steelblue",
border = "white",
...)
graphics::abline(v = threshold, lty = 2, col = "firebrick", lwd = 2)
},
variance = {
# Variance ratio plot
var_ratios <- x$var_stats$var_ratio
vars <- x$var_stats$variable
n <- length(var_ratios)
# Order by distance from 1
ord <- order(abs(log(var_ratios)), decreasing = TRUE)
var_ratios <- var_ratios[ord]
vars <- vars[ord]
old_par <- graphics::par(mar = c(4, 8, 3, 2))
on.exit(graphics::par(old_par))
graphics::plot(var_ratios, seq_len(n),
xlim = c(0.5, max(2, max(var_ratios, na.rm = TRUE) * 1.1)),
yaxt = "n",
xlab = "Variance Ratio (left/right)",
ylab = "",
main = "Balance: Variance Ratios",
pch = 19,
col = ifelse(var_ratios < 0.5 | var_ratios > 2, "firebrick", "steelblue"),
log = "x",
...)
graphics::axis(2, at = seq_len(n), labels = vars, las = 1, cex.axis = 0.8)
graphics::abline(v = 1, lty = 1, col = "gray50")
graphics::abline(v = c(0.5, 2), lty = 2, col = "firebrick")
}
)
invisible(x)
}
# Helper for NULL coalescing
`%||%` <- function(x, y) {
if (is.null(x)) y else x
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.