#' Fit Augmented SCM with multiple outcomes
#'
#' @param form Formula of the form `outcome1 + outcome2 ~ treatment`,
#'   optionally followed by `| auxiliary covariates`. Every variable on the
#'   left-hand side is treated as a separate outcome.
#' @param unit Name of unit column
#' @param time Name of time column
#' @param t_int Time of intervention
#' @param data Panel data as dataframe
#' @param progfunc What function to use to impute control outcomes
#'   Ridge=Ridge regression (allows for standard errors),
#'   None=No outcome model.
#'   Matched case-insensitively. If a vector is supplied (as with the default)
#'   only its first element is used.
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param combine_method How to combine outcomes: `concat` concatenates outcomes and `avg` averages them, default: 'avg'
#' @param ... optional arguments for outcome model
#'
#' @return augsynth object that contains:
#'         \itemize{
#'          \item{"weights"}{Ridge ASCM weights}
#'          \item{"l2_imbalance"}{Imbalance in pre-period outcomes, measured by the L2 norm}
#'          \item{"scaled_l2_imbalance"}{L2 imbalance scaled by L2 imbalance of uniform weights}
#'          \item{"mhat"}{Outcome model estimate}
#'          \item{"data"}{Panel data as matrices}
#'         }
#' @export
augsynth_multiout <- function(form, unit, time, t_int, data,
                              progfunc = c("Ridge", "None"),
                              scm = TRUE,
                              fixedeff = FALSE,
                              cov_agg = NULL,
                              combine_method = "avg",
                              ...) {
  call_name <- match.call()

  # The default for `progfunc` is a vector of the allowed choices. Reduce it
  # to a single value so that the scalar `if` below (and any scalar test
  # downstream) does not receive a condition of length > 1.
  progfunc <- progfunc[1]

  # only allow ridge augmentation; checked first so a bad argument fails
  # before any data formatting is done
  if (is.null(progfunc) || is.na(progfunc) ||
      !tolower(progfunc) %in% c("none", "ridge")) {
    stop(paste(progfunc, "is not a valid augmentation function with multiple outcomes. Only `none` or `ridge` are allowable options for `progfunc`"))
  }

  form <- Formula::Formula(form)
  unit <- enquo(unit)
  time <- enquo(time)

  ## format data
  main_terms <- terms(formula(form, rhs = 1))
  outcome <- main_terms[[2]]
  trt <- main_terms[[3]]
  outcomes_str <- all.vars(outcome)
  outcomes <- sapply(outcomes_str, quo)

  # get outcomes as a list
  wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data)

  ## add covariates (only when the formula has a second right-hand side part)
  if (length(form)[2] == 2) {
    cov_form <- paste(deparse(terms(formula(form, rhs = 2))[[3]]),
                      collapse = "")
    new_form <- as.formula(paste("~", cov_form))
    Z <- extract_covariates(new_form, unit, time, t_int, data, cov_agg)
  } else {
    Z <- NULL
  }

  # fit augmented SCM
  augsynth <- fit_augsynth_multiout_internal(wide_list, combine_method, Z,
                                             progfunc, scm,
                                             fixedeff, outcomes_str, ...)

  # add some extra data
  augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time)
  augsynth$call <- call_name
  augsynth$t_int <- t_int
  augsynth$combine_method <- combine_method

  # label the weights with the units that never receive treatment
  treated_units <- data %>%
    filter(!!trt == 1) %>%
    distinct(!!unit) %>%
    pull(!!unit)
  control_units <- data %>%
    filter(!(!!unit %in% treated_units)) %>%
    distinct(!!unit) %>%
    pull(!!unit)
  augsynth$weights <- matrix(augsynth$weights)
  rownames(augsynth$weights) <- control_units

  return(augsynth)
}
#' Internal function to fit augmented SCM with multiple outcomes
#' @param wide_list List of matrices for each outcome formatted from format_data
#' @param combine_method How to combine outcomes
#' @param Z Matrix of auxiliary covariates
#' @param progfunc outcome model to use
#' @param scm Whether to fit SCM
#' @param fixedeff Whether to de-mean synth
#' @param outcomes_str Outcome names, stored on the result as `outcomes`
#' @param ... Extra args for outcome model; also forwarded to `combine_outcomes`
#' @noRd
fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z,
                                           progfunc, scm, fixedeff,
                                           outcomes_str, ...) {
  # build the balance data, the fixed-effect offsets and the balance weights
  combined <- combine_outcomes(wide_list, combine_method, fixedeff, ...)
  bal_data <- combined$wide_bal

  synth_data <- do.call(format_synth, bal_data)

  # the plotting series use the raw outcomes, concatenated column-wise
  X <- do.call(cbind, wide_list$X)
  y <- do.call(cbind, wide_list$y)
  trt <- wide_list$trt
  raw_series <- cbind(X, y)
  synth_data$Y0plot <- t(raw_series[trt == 0, , drop = FALSE])
  synth_data$Y1plot <- colMeans(raw_series[trt == 1, , drop = FALSE])

  fit <- fit_augsynth_internal(bal_data, synth_data, Z, progfunc,
                               scm, fixedeff, V = combined$V, ...)

  # replace the fitted mhat with the offsets returned by combine_outcomes
  fit$mhat <- combined$mhat
  fit$data <- list(X = X, trt = trt, y = y, Z = Z)
  fit$data_list <- wide_list
  fit$outcomes <- outcomes_str

  ## format output
  class(fit) <- c("augsynth_multiout", "augsynth")
  fit
}
#' Helper function to combine multiple outcomes into a single balance matrix
#' @param wide_list List of lists of pre/post treatment data for each outcome
#' @param combine_method How to combine outcomes, one of 'avg', 'concat',
#'   'avg_concat'
#' @param fixedeff Whether to take out unit fixed effects or not
#' @param nu Weighting between concatenated and averaged objectives; required
#'   (a single number in [0, 1]) when `combine_method` is 'avg_concat' and
#'   unused otherwise
#' @param ... Extra arguments for combination (currently ignored)
#' @noRd
#' @return \itemize{
#'          \item{"wide_bal"}{List with the combined pre-treatment matrix `X`, the combined post-treatment matrix `y`, and the treatment vector `trt`}
#'          \item{"mhat"}{Matrix of unit fixed effects (all zeros when `fixedeff` is FALSE), pre-treatment columns of every outcome followed by post-treatment columns of every outcome}
#'          \item{"V"}{Balance weights: a vector for 'concat' and 'avg_concat', an identity matrix for 'avg'}
#'         }
combine_outcomes <- function(wide_list, combine_method, fixedeff,
                             nu = NULL, ...) {
  n_outs <- length(wide_list$X)
  if (n_outs == 0) {
    stop("`wide_list$X` must contain at least one outcome")
  }
  out_idx <- seq_len(n_outs)
  n_units <- max(vapply(wide_list$X, nrow, integer(1)))
  is_ctrl <- wide_list$trt == 0

  # keep only the time periods (columns) without any missing value
  drop_na_periods <- function(x) t(na.omit(t(x)))

  # center and scale an outcome by the mean and sd of the control units
  standardize <- function(x) {
    x_ctrl <- x[is_ctrl, , drop = FALSE]
    (x - mean(x_ctrl, na.rm = TRUE)) / sd(x_ctrl, na.rm = TRUE)
  }

  # matrix whose `n_cols` columns are all equal to `unit_means`
  fill_cols <- function(unit_means, n_cols) {
    matrix(rep(unit_means, times = n_cols),
           nrow = length(unit_means),
           dimnames = list(names(unit_means), NULL))
  }

  # take out unit fixed effects: the pre-treatment mean of each unit
  demean_outcome <- function(j) {
    pre <- wide_list$X[[j]]
    post <- wide_list$y[[j]]
    unit_means <- rowMeans(pre, na.rm = TRUE)
    list(X = pre - unit_means,
         y = post - unit_means,
         mhat_pre = fill_cols(unit_means, ncol(pre)),
         mhat_post = fill_cols(unit_means, ncol(post)))
  }

  if (fixedeff) {
    demeaned <- lapply(out_idx, demean_outcome)
    wide_list$X <- lapply(demeaned, function(d) d$X)
    wide_list$y <- lapply(demeaned, function(d) d$y)
    mhat_pre <- lapply(demeaned, function(d) d$mhat_pre)
    mhat_post <- lapply(demeaned, function(d) d$mhat_post)
  } else {
    mhat_pre <- lapply(
      out_idx,
      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$X[[j]])))
    mhat_post <- lapply(
      out_idx,
      function(j) matrix(0, nrow = n_units, ncol = ncol(wide_list$y[[j]])))
  }

  # combine outcomes
  if (combine_method == "concat") {
    wide_bal <- list(X = do.call(cbind, lapply(wide_list$X, drop_na_periods)),
                     y = do.call(cbind, lapply(wide_list$y, drop_na_periods)),
                     trt = wide_list$trt)
    # V scales by inverse control sd of the outcome and by the square root of
    # its number of complete pre-treatment periods
    V <- do.call(c, lapply(wide_list$X, function(x) {
      n_periods <- nrow(na.omit(t(x)))
      scale_j <- sqrt(n_periods) * sd(x[is_ctrl, , drop = FALSE], na.rm = TRUE)
      rep(1 / scale_j, n_periods)
    }))
    # (an SVD-based combination was sketched here previously but never enabled)
  } else if (combine_method == "avg") {
    # average standardized pre-treatment outcomes, ignoring missing values
    X_avg <- rowMeans(simplify2array(lapply(wide_list$X, standardize)),
                      dims = 2, na.rm = TRUE)
    # remove any time periods with NAs
    X_avg <- drop_na_periods(X_avg)
    wide_bal <- list(X = X_avg,
                     y = rowMeans(simplify2array(wide_list$y),
                                  dims = 2, na.rm = TRUE),
                     trt = wide_list$trt)
    V <- diag(ncol(wide_bal$X))
  } else if (combine_method == "avg_concat") {
    # `nu` has no usable default; fail with a clear message rather than with
    # an opaque error from sqrt()
    if (is.null(nu) || !is.numeric(nu) || length(nu) != 1 ||
        is.na(nu) || nu < 0 || nu > 1) {
      stop("combine_method 'avg_concat' requires `nu` to be a single number between 0 and 1")
    }
    # standardize the outcomes, then build both the averaged and the
    # concatenated versions
    X_list_std <- lapply(wide_list$X, standardize)
    X_avg <- rowMeans(simplify2array(X_list_std), dims = 2, na.rm = TRUE)
    # remove any time periods with NAs
    X_avg <- drop_na_periods(X_avg)
    X_concat <- do.call(cbind, lapply(X_list_std, drop_na_periods))
    # V assigns weight nu to the averaged objective and (1 - nu) to the
    # concatenated objective
    V <- c(rep(sqrt(nu), ncol(X_avg)),
           rep(sqrt(1 - nu) / sqrt(n_outs), ncol(X_concat)))
    wide_bal <- list(
      X = cbind(X_avg, X_concat),
      y = do.call(cbind, lapply(wide_list$y, drop_na_periods)),
      trt = wide_list$trt
    )
  } else {
    stop(paste("combine_method should be one of ('avg', 'concat', 'avg_concat'),",
               combine_method, " is not a valid combining option"))
  }

  mhat <- cbind(do.call(cbind, mhat_pre), do.call(cbind, mhat_post))
  return(list(wide_bal = wide_bal, mhat = mhat, V = V))
}
#' Get prediction of ATT or average outcome under control
#' @param object augsynth_multiout object
#' @param ... Optional arguments passed on unchanged to the next `predict`
#'   method, including \itemize{\item{"att"}{Whether to return the ATT or average outcome under control}}
#'
#' @return Matrix of predictions with one row per time period and one column
#'   per outcome; cells with no prediction for that outcome and period are NA
#' @export
predict.augsynth_multiout <- function(object, ...) {
  # call augsynth predict; `...` (including `att`) is forwarded as-is
  pred <- NextMethod()

  pre_list <- object$data_list$X
  post_list <- object$data_list$y
  n_outs <- length(pre_list)
  out_idx <- seq_len(n_outs)

  # number of pre- and post-treatment columns for every outcome
  n_pre <- vapply(out_idx, function(k) ncol(pre_list[[k]]), integer(1))
  n_post <- vapply(out_idx, function(k) ncol(post_list[[k]]), integer(1))

  # separate out by outcome: rows are the time labels of the outcome with the
  # most periods, columns are the outcomes
  time_labels <- lapply(out_idx, function(k) {
    colnames(cbind(pre_list[[k]], post_list[[k]]))
  })
  pred_reshape <- matrix(NA, ncol = n_outs, nrow = max(n_pre + n_post))
  rownames(pred_reshape) <- time_labels[[which.max(lengths(time_labels))]]
  colnames(pred_reshape) <- object$outcomes

  # outcome label for every element of `pred`: all pre-treatment blocks come
  # first (outcome by outcome), then all post-treatment blocks
  outcome_names <- object$outcomes[out_idx]
  pre_outs <- rep(outcome_names, times = n_pre)
  post_outs <- rep(outcome_names, times = n_post)

  # fill by (time label, outcome) using the names carried by `pred`
  pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred
  return(pred_reshape)
}
#' Print function for augsynth
#' @param x augsynth_multiout object
#' @param ... Optional arguments
#' @return `x`, invisibly
#' @export
print.augsynth_multiout <- function(x, ...) {
  ## straight from lm
  cat("\nCall:\n", paste(deparse(x$call), sep = "\n", collapse = "\n"),
      "\n\n", sep = "")

  ## print att estimates, averaged over the periods at or after t_int
  ## (row names of the prediction matrix are read as numeric times)
  att <- predict(x, att = TRUE)
  is_post <- as.numeric(rownames(att)) >= x$t_int
  att_post <- data.frame(colMeans(att[is_post, , drop = FALSE]))
  names(att_post) <- c("")

  cat("Average ATT Estimate:\n")
  print(att_post)
  cat("\n\n")

  # print methods return their argument invisibly
  invisible(x)
}
#' Summary function for augsynth
#' @param object augsynth_multiout object
#' @param inf whether or not to perform inference
#' @param inf_type Type of inference, default is "conformal" (the only type
#'   supported for multiple outcomes)
#' @param grid_size Grid to compute prediction intervals over, default is 1 and only p-values are computed
#' @param ... Optional arguments forwarded to `conformal_inf_multiout` when
#'   `inf` is TRUE
#' @return Object of class `summary.augsynth_multiout`: a list with the long
#'   format estimates `att` (including an "Average" pseudo-outcome), the
#'   per-outcome `average_att`, `t_int`, `call`, `l2_imbalance`,
#'   `scaled_l2_imbalance`, `inf_type` and `bias_est`
#' @export
summary.augsynth_multiout <- function(object, inf = TRUE, inf_type = "conformal", grid_size = 1, ...) {

  summ <- list()
  outcomes <- object$outcomes

  # Reshape a (time x outcome) matrix into long format: one row per
  # (Time, Outcome) pair, with the cell values in a column named `value_name`.
  long_by_time <- function(mat, value_name) {
    wide <- data.frame(mat)
    names(wide) <- outcomes
    wide$Time <- object$data$time
    long <- gather(wide, Outcome, value, -Time)
    names(long)[ncol(long)] <- value_name
    long
  }

  # Same reshaping for a single row of per-outcome values (no Time column).
  long_average <- function(mat_row, value_name) {
    wide <- data.frame(mat_row)
    names(wide) <- outcomes
    long <- gather(wide, Outcome, value)
    names(long)[ncol(long)] <- value_name
    long
  }

  if (inf) {
    if (inf_type != "conformal") {
      stop("Only conformal inference is supported for multiple outcomes")
    }
    if (grid_size > 1) {
      warning(paste0("A grid size of ", grid_size, " will require ",
                     grid_size, "^", length(object$outcomes),
                     " = ", grid_size ^ length(object$outcomes),
                     " evaluations. This could take a while..."))
    }
    # NOTE(review): `grid_size` is a named argument of this function, so it is
    # not part of `...` and is not forwarded here -- confirm whether
    # conformal_inf_multiout() should receive it.
    att_se <- conformal_inf_multiout(object, ...)

    # The last row of each inference matrix is used as the per-outcome
    # average; the rows before it are matched to the time periods.
    t_final <- nrow(att_se$att)
    period_rows <- seq_len(t_final - 1)
    by_time <- c("Time", "Outcome")

    att <- long_by_time(att_se$att[period_rows, , drop = FALSE], "Estimate") %>%
      inner_join(long_by_time(att_se$lb[period_rows, , drop = FALSE],
                              "lower_bound"),
                 by = by_time) %>%
      inner_join(long_by_time(att_se$ub[period_rows, , drop = FALSE],
                              "upper_bound"),
                 by = by_time) %>%
      inner_join(long_by_time(att_se$p_val[period_rows, , drop = FALSE],
                              "p_val"),
                 by = by_time)

    average_att <- long_average(att_se$att[t_final, , drop = FALSE],
                                "Estimate") %>%
      inner_join(long_average(att_se$lb[t_final, , drop = FALSE],
                              "lower_bound"),
                 by = "Outcome") %>%
      inner_join(long_average(att_se$ub[t_final, , drop = FALSE],
                              "upper_bound"),
                 by = "Outcome") %>%
      inner_join(long_average(att_se$p_val[t_final, , drop = FALSE],
                              "p_val"),
                 by = "Outcome")

    # with a grid of size 1 only p-values are reported, so blank the bounds
    if (grid_size == 1) {
      att <- att %>% mutate(lower_bound = NA, upper_bound = NA)
      average_att <- average_att %>%
        mutate(lower_bound = NA, upper_bound = NA)
    }
  } else {
    att_est <- predict(object, att = TRUE)
    att <- long_by_time(att_est, "Estimate")
    att$Std.Error <- NA

    # `data_list$X` holds the pre-treatment columns, so the post-treatment
    # rows of `att_est` are the ones AFTER the first `n_pre` rows. (Starting
    # the average at row `n_pre` would include the last pre-treatment period.)
    n_pre <- min(vapply(object$data_list$X, ncol, integer(1)))
    is_post <- seq_len(nrow(att_est)) > n_pre
    average_att <- long_average(
      t(colMeans(att_est[is_post, , drop = FALSE])),
      "Estimate"
    )
    average_att$Std.Error <- NA
  }

  # get average of all outcomes, each scaled by the sd of its pre-treatment
  # values among control units
  is_ctrl <- object$data_list$trt == 0
  sds <- data.frame(
    Outcome = outcomes,
    sdo = vapply(object$data_list$X,
                 function(x) sd(x[is_ctrl, ], na.rm = TRUE),
                 numeric(1))
  )
  avg_rows <- att %>%
    inner_join(sds, by = "Outcome") %>%
    mutate(Estimate = Estimate / sdo) %>%
    group_by(Time) %>%
    summarise(Estimate = mean(Estimate, na.rm = TRUE)) %>%
    mutate(Outcome = "Average")
  att <- bind_rows(att, avg_rows)

  summ$att <- att
  summ$average_att <- average_att
  summ$t_int <- object$t_int
  summ$call <- object$call
  summ$l2_imbalance <- object$l2_imbalance
  summ$scaled_l2_imbalance <- object$scaled_l2_imbalance
  summ$inf_type <- inf_type

  ## get estimated bias; not reported without an outcome model or without SCM
  if (object$progfunc == "None" || (!object$scm)) {
    summ$bias_est <- NA
  } else {
    if (object$progfunc == "Ridge") {
      mhat <- object$ridge_mhat
      w <- object$synw
    } else {
      mhat <- object$mhat
      w <- object$weights
    }
    trt <- object$data$trt
    m1 <- colMeans(mhat[trt == 1, , drop = FALSE])
    summ$bias_est <- m1 - t(mhat[trt == 0, , drop = FALSE]) %*% w
  }

  class(summ) <- "summary.augsynth_multiout"
  return(summ)
}
#' Print function for summary function for augsynth
#' @param x summary.augsynth_multiout object
#' @param ... Optional arguments
#' @return `x`, invisibly
#' @export
print.summary.augsynth_multiout <- function(x, ...) {
  ## straight from lm
  cat("\nCall:\n", paste(deparse(x$call), sep = "\n", collapse = "\n"),
      "\n\n", sep = "")

  ## get pre-treatment fit by outcome: RMSE of the estimates before t_int
  imbal <- x$att %>%
    filter(Time < x$t_int) %>%
    group_by(Outcome) %>%
    summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2, na.rm = TRUE)))

  cat(paste0("Overall L2 Imbalance (Scaled):",
             format(round(x$l2_imbalance, 3), nsmall = 3), "  (",
             format(round(x$scaled_l2_imbalance, 3), nsmall = 3), ")\n\n"))

  cat("Average ATT Estimate:\n")
  print(inner_join(x$average_att, imbal, by = "Outcome"))
  cat("\n\n")

  # print methods return their argument invisibly
  invisible(x)
}
#' Plot function for augsynth with multiple outcomes
#'
#' Summarizes the fitted object and plots the resulting summary.
#' @importFrom graphics plot
#' @param x augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default TRUE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Optional arguments for summary function
#'
#' @export
plot.augsynth_multiout <- function(x, inf = T, plt_avg = F, ...) {
  # `inf` and `plt_avg` only control the plot; `...` goes to summary()
  fit_summary <- summary(x, ...)
  plot(fit_summary, inf = inf, plt_avg = plt_avg)
}
#' Plot function for summary function for augsynth
#'
#' Draws one panel per outcome with the estimate over time, optionally with
#' an uncertainty ribbon.
#' @param x summary.augsynth_multiout object
#' @param inf Boolean, whether to plot uncertainty intervals, default TRUE
#' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE
#' @param ... Optional arguments (unused)
#'
#' @export
plot.summary.augsynth_multiout <- function(x, inf = T, plt_avg = F, ...) {
  # drops the "Average" pseudo-outcome; also used as layer data for ribbons
  without_average <- function(d) filter(d, Outcome != "Average")

  plot_data <- if (plt_avg) x$att else without_average(x$att)
  p <- ggplot2::ggplot(plot_data, ggplot2::aes(x = Time, y = Estimate))

  if (inf) {
    # choose the ribbon limits from the type of inference in the summary
    ribbon_aes <- NULL
    if (x$inf_type == "jackknife") {
      ribbon_aes <- ggplot2::aes(ymin = Estimate - 2 * Std.Error,
                                 ymax = Estimate + 2 * Std.Error)
    } else if (x$inf_type %in% c("conformal", "jackknife+")) {
      ribbon_aes <- ggplot2::aes(ymin = lower_bound,
                                 ymax = upper_bound)
    }
    if (!is.null(ribbon_aes)) {
      p <- p + ggplot2::geom_ribbon(ribbon_aes, alpha = 0.2,
                                    data = without_average)
    }
  }

  p +
    ggplot2::geom_line() +
    ggplot2::geom_vline(xintercept = x$t_int, lty = 2) +
    ggplot2::geom_hline(yintercept = 0, lty = 2) +
    ggplot2::facet_wrap(~ Outcome, scales = "free_y") +
    ggplot2::theme_bw()
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.