#' @title Visualize ATE Estimates by Contour Plot
#'
#' @description Visualize the same level of Average Treatment Effect (ATE)
#' for different choices of sensitivity parameters with contour lines.
#' @usage contour_ate(x_trt, y_trt, x_ctrl, y_ctrl,
#' largest_effect, gamma_length = 11, ...)
#' @param x_trt a \code{tibble} or data frame with observed pre-treatment variables for the treatment group
#' @param y_trt a vector with outcomes for the treatment group
#' @param x_ctrl a \code{tibble} or data frame with observed pre-treatment variables for the control group
#' @param y_ctrl a vector with outcomes for the control group
#' @param largest_effect the largest magnitude of sensitivity parameter to be considered, chosen from \code{\link{caliplot}}
#' @param gamma_length chosen length of sensitivity parameter sequence, which needs to be an odd integer
#' @param joint logical. If TURE, the mean surface and residual variance will be estimated jointly for both treatment
#' groups; if FALSE (default), the mean surface and residual variance will be estimated independently for
#' each treatment group.
#' @param ... arguments passed to the function \code{BART::wbart}
#' @section Details:
#' For analysis details, please see \code{\link{heatmap_ate}}
#'
#' @export
#'
#' @examples
#' # Observed data in treatment group
#' NHANES_trt <- NHANES %>% dplyr::filter(trt_dbp == 1)
#' x_trt <- NHANES_trt %>% select(-one_of("trt_dbp", "ave_dbp"))
#' y_trt <- NHANES_trt %>% select(ave_dbp)
#'
#' # Observed data in control group
#' NHANES_ctrl <- NHANES %>% dplyr::filter(trt_dbp == 0)
#' x_ctrl <- NHANES_ctrl %>% select(-one_of("trt_dbp", "ave_dbp"))
#' y_ctrl <- NHANES_ctrl %>% select(ave_dbp)
#'
#' # ATE Contour Plot
#' contour_ate(x_trt, y_trt, x_ctrl, y_ctrl, largest_effect = 0.05)
#' contour_ate(x_trt, y_trt, x_ctrl, y_ctrl, largest_effect = 0.05, joint = TRUE)
contour_ate = function(x_trt, y_trt, x_ctrl, y_ctrl, largest_effect, gamma_length = 11, joint = FALSE, ...){
if(joint){
## prepare joint dataset ##
x_joint = rbind(x_trt, x_ctrl)
trt = c(rep(1, nrow(x_trt)), rep(0, nrow(x_ctrl)))
x_train_joint = cbind(x_joint, trt) ## trt as a predictor when joint = T
x_test_joint = cbind(x_joint, 1-trt)
y_train_joint = rbind(y_trt, y_ctrl)
## jointly fit the model using BART ##
joint_bart_fit <- BART::wbart(as.matrix(x_train_joint), as.matrix(y_train_joint), x.test=as.matrix(x_test_joint), ...)
mu_ctrl_obs_joint <- joint_bart_fit$yhat.train[ , trt == 0] ## Y.hat(Y(0)|T=0), observed
mu_ctrl_test_joint <- joint_bart_fit$yhat.test[ , trt == 1] ## Y.hat(Y(1)|T=0), missing
mu_trt_obs_joint <- joint_bart_fit$yhat.train[ , trt == 1] ## Y.hat(Y(1)|T=1), observed
mu_trt_test_joint <- joint_bart_fit$yhat.test[ , trt == 0] ## Y.hat(Y(0)|T=1), missing
sigma_joint <- joint_bart_fit$sigma[101:1100]
nctrl <- ncol(mu_ctrl_obs_joint)
ntreat <- ncol(mu_ctrl_test_joint)
q <- rbeta(1000, ntreat + 1, nctrl + 1)
gamma_0 <- c(seq(from = -largest_effect, to = 0, length.out = (gamma_length + 1) / 2),
seq(from= 0, to = largest_effect, length.out = (gamma_length + 1) / 2)[-1])
gamma_1 <- gamma_0
gamma_grid <- expand.grid(gamma_0, gamma_1)
colnames(gamma_grid) = c("gamma_0", "gamma_1")
att_joint <- mu_trt_obs_joint %o% rep(1, nrow(gamma_grid)) - mu_ctrl_test_joint %o% rep(1, nrow(gamma_grid)) -
matrix(sigma_joint^2, nrow = length(sigma_joint), ncol = ncol(mu_trt_obs_joint)) %o% gamma_grid[, 1]
atc_joint <- mu_trt_test_joint %o% rep(1, nrow(gamma_grid)) - mu_ctrl_obs_joint %o% rep(1, nrow(gamma_grid)) -
matrix(sigma_joint^2, nrow = length(sigma_joint), ncol = ncol(mu_trt_test_joint)) %o% gamma_grid[, 2]
att_mean_joint <- apply(att_joint, c(1, 3), mean)
atc_mean_joint <- apply(atc_joint, c(1, 3), mean)
ate_mean_joint <- q * att_mean_joint + (1-q) * atc_mean_joint
ate = apply(ate_mean_joint, 2, mean)
ate_contour = cbind(gamma_grid, ate)
contour_plot = ggplot() +
theme_bw() +
xlab(expression(gamma[0])) +
ylab(expression(gamma[1])) +
ggtitle("Contour Plot") +
theme(plot.title = element_text(hjust = 0.5)) +
stat_contour(data = ate_contour,
aes(x = gamma_0, y = gamma_1, z = ate, color = ..level..),
binwidth = 1)+
scale_color_continuous(name = "ATE")
directlabels::direct.label(contour_plot,"bottom.pieces")
}
if(!joint)
{
# Observed X and y in T=1 #
x_train_trt = x_trt
y_train_trt = y_trt
# Observed X and y in T=0 #
x_train_ctrl = x_ctrl
y_train_ctrl = y_ctrl
# X for missing outcomes #
x_test_trt = x_ctrl
x_test_ctrl = x_trt
# Fitting Bart for T=1 #
trt_bart_fit <- BART::wbart(as.matrix(x_train_trt), as.matrix(y_train_trt), as.matrix(x_test_trt))
# estimated result for T=1 #
mu_trt_obs <- trt_bart_fit$yhat.train ## 1000 draws * n_trt obs
mu_trt_test <- trt_bart_fit$yhat.test ## 1000 draws * n_ctrl obs
sig_trt_obs <- trt_bart_fit$sigma[101:1100]
# Fitting Bart in T=0 #
ctrl_bart_fit <- BART::wbart(as.matrix(x_train_ctrl), as.matrix(y_train_ctrl), as.matrix(x_test_ctrl))
# estimated results in T=0 #
mu_ctrl_obs <- ctrl_bart_fit$yhat.train ## 1000 draws * n_ctrl obs
mu_ctrl_test <- ctrl_bart_fit$yhat.test ## 1000 draws * n_trt obs
sig_ctrl_obs <- ctrl_bart_fit$sigma[101:1100]
# sample weight from posterior of f(T) #
q <- rbeta(1000, nrow(x_train_trt) + 1, nrow(x_train_ctrl) + 1)
# gamma.grid #
gamma_0 <- c(seq(from = -largest_effect, to = 0, length.out = (gamma_length + 1) / 2),
seq(from= 0, to = largest_effect, length.out = (gamma_length + 1) / 2)[-1])
gamma_1 <- gamma_0
gamma_grid <- expand.grid(gamma_0, gamma_1)
colnames(gamma_grid) = c("gamma_0", "gamma_1")
att <- mu_trt_obs %o% rep(1, nrow(gamma_grid)) - mu_ctrl_test %o% rep(1, nrow(gamma_grid)) -
matrix(sig_trt_obs^2, nrow = length(sig_trt_obs), ncol = ncol(mu_trt_obs)) %o% gamma_grid[, 1]
atc <- mu_trt_test %o% rep(1, nrow(gamma_grid)) -
matrix(sig_ctrl_obs^2, nrow = length(sig_ctrl_obs), ncol = ncol(mu_trt_test)) %o% gamma_grid[, 2] -
mu_ctrl_obs %o% rep(1, nrow(gamma_grid))
# E(Y(1)-Y(0)|T=1) #
att_mean <- apply(att, c(1, 3), mean)
# E(Y(1)-Y(0)|T=0) #
atc_mean <- apply(atc, c(1, 3), mean)
# average treatment effect E(Y(1)-Y(0)) #
ate_mean <- q * att_mean + (1-q) * atc_mean ## ndpost*(n_gamma0*n_gamma1)
ate = apply(ate_mean, 2, mean)
ate_contour = cbind(gamma_grid, ate)
contour_plot = ggplot() +
theme_bw() +
xlab(expression(gamma[0])) +
ylab(expression(gamma[1])) +
ggtitle("Contour Plot") +
theme(plot.title = element_text(hjust = 0.5)) +
stat_contour(data = ate_contour,
aes(x = gamma_0, y = gamma_1, z = ate, color = ..level..),
binwidth = 1)+
scale_color_continuous(name = "ATE")
directlabels::direct.label(contour_plot,"bottom.pieces")
}
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.