Nothing
#' Flexible Monte Carlo sensitivity analysis for unmeasured confounding
#'
#' The function \code{sa} implements the flexible sensitivity analysis
#' approach for unmeasured confounding with multiple treatments
#' and a binary outcome. The current implementation supports exactly three
#' treatment groups, coded 1, 2 and 3.
#'
#' For every set of confounding-function values and every draw of the
#' generalized propensity score (GPS), the outcome of a unit in treatment
#' group \code{a} is adjusted by subtracting \code{c(a,b)} times the GPS of
#' group \code{b}, summed over the two other groups \code{b}. A BART model
#' (\code{\link[BART]{wbart}}) is then fitted to the adjusted outcome and
#' risk differences are computed from its predictions under each treatment.
#'
#' @param y A numeric vector (0, 1) representing a binary outcome.
#' @param x A dataframe, including all the covariates but not treatments.
#' @param w A numeric vector representing the treatment groups, coded
#' 1, 2 and 3.
#' @param formula A \code{\link[stats]{formula}} object for the analysis.
#' The default is to use all terms specified in \code{x}.
#' @param prior_c_function The six confounding functions, in the order
#' c(1,2), c(2,1), c(2,3), c(1,3), c(3,1), c(3,2). One of:
#' 1) A vector of characters indicating the
#' prior distributions for the confounding functions.
#' Each character contains the random number generation code
#' from the standard probability
#' \code{\link[stats:Distributions]{distributions}}
#' in the \code{\link[stats:stats-package]{stats}} package, written without
#' the number of draws (e.g. \code{"runif(-0.6, 0)"}); \code{m2} draws are
#' taken from each.
#' 2) A vector of characters including the grid specifications for
#' the confounding functions (any element containing \code{seq}, e.g.
#' \code{"seq(-0.6, 0, by = 0.3)"}). It should be used when users want to
#' formulate the confounding functions as scalar values. The analysis is
#' repeated for every combination of the grid values; the remaining elements
#' are drawn as in 1).
#' 3) A numeric matrix indicating the point mass prior for the confounding
#' functions; its column means are used as a single set of values.
#' @param m1 A numeric value indicating the number of draws of the GPS
#' from the posterior predictive distribution
#' @param m2 A numeric value indicating the number of draws from
#' the prior distributions of the confounding functions. Required when
#' \code{prior_c_function} only contains random-draw specifications; when a
#' grid is present it is replaced by the number of grid combinations, and it
#' is ignored for a numeric matrix.
#' @param n_cores A numeric value indicating number of cores to use
#' for parallel computing.
#' @param estimand A character string representing the type of
#' causal estimand. Only \code{"ATT"} or \code{"ATE"} is allowed.
#' When the \code{estimand = "ATT"}, users also need to specify the
#' reference treatment group by setting the \code{reference_trt} argument.
#' @param reference_trt A numeric value indicating reference treatment group
#' for ATT effect.
#' @param ... Other parameters passed to \code{\link[BART]{wbart}} when
#' fitting the outcome model.
#'
#' @return A list of causal estimands including risk differences (RD)
#' between different treatment groups: for \code{"ATE"} the contrasts
#' 1 vs 2, 1 vs 3 and 2 vs 3 (group \code{k} minus group \code{m}); for
#' \code{"ATT"} the reference group minus each other group, predicted on the
#' units that received the reference treatment.
#' If \code{prior_c_function} contains a grid, the result has class
#' \code{"CIMTx_sa_grid"} and holds, for each contrast (named e.g.
#' \code{ATE12} or \code{ATT12}), the posterior mean risk difference for
#' every grid row, together with \code{c_functions} (the matrix of
#' confounding-function values, one row per grid combination) and
#' \code{grid_index} (the positions of the grid elements in
#' \code{prior_c_function}).
#' Otherwise the result has class \code{"CIMTx_ATE_sa"} or
#' \code{"CIMTx_ATT_sa"} and holds the posterior draws of each risk
#' difference (named e.g. \code{ATE_RD12} or \code{ATT_RD12}), with one row
#' per set of confounding-function values when there is more than one.
#'
#' @export
#' @importFrom foreach %dopar% foreach
#' @importFrom stringr str_detect str_locate str_sub
#' @importFrom tidyr expand_grid
#' @importFrom BART mbart2 wbart pwbart
#' @importFrom parallel makeCluster stopCluster
#' @importFrom doParallel registerDoParallel
#' @references
#'
#' Hadley Wickham (2019).
#' \emph{stringr: Simple, Consistent Wrappers for Common String Operations}.
#' R package version 1.4.0.
#' URL:\url{https://CRAN.R-project.org/package=stringr}
#'
#' Hadley Wickham (2021).
#' \emph{tidyr: Tidy Messy Data}.
#' R package version 1.1.4.
#' URL:\url{https://CRAN.R-project.org/package=tidyr}
#'
#' Sparapani R, Spanbauer C, McCulloch R
#' Nonparametric Machine Learning and
#' Efficient Computation with Bayesian Additive Regression Trees:
#' The BART R Package. \emph{Journal of Statistical Software},
#' \strong{97}(1), 1-66.
#'
#' Microsoft Corporation and Steve Weston (2020).
#' \emph{doParallel: Foreach Parallel Adaptor for the 'parallel' Package}.
#' R package version 1.0.16.
#' URL:\url{https://CRAN.R-project.org/package=doParallel}
#'
#' Microsoft and Steve Weston (2020).
#' \emph{foreach: Provides Foreach Looping Construct.}.
#' R package version 1.5.1
#' URL:\url{https://CRAN.R-project.org/package=foreach}
#' @examples
#' \donttest{
#' lp_w_all <-
#' c(
#' ".4*x1 + .1*x2 - 1.1*x4 + 1.1*x5", # w = 1
#' ".2 * x1 + .2 * x2 - 1.2 * x4 - 1.3 * x5"
#' ) # w = 2
#' nlp_w_all <-
#' c(
#' "-.5*x1*x4 - .1*x2*x5", # w = 1
#' "-.3*x1*x4 + .2*x2*x5"
#' ) # w = 2
#' lp_y_all <- rep(".2*x1 + .3*x2 - .1*x3 - 1.1*x4 - 1.2*x5", 3)
#' nlp_y_all <- rep(".7*x1*x1 - .1*x2*x3", 3)
#' X_all <- c(
#' "rnorm(0, 0.5)", # x1
#' "rbeta(2, .4)", # x2
#' "runif(0, 0.5)", # x3
#' "rweibull(1,2)", # x4
#' "rbinom(1, .4)" # x5
#' )
#' set.seed(1111)
#' data <- data_sim(
#' sample_size = 100,
#' n_trt = 3,
#' x = X_all,
#' lp_y = lp_y_all,
#' nlp_y = nlp_y_all,
#' align = FALSE,
#' lp_w = lp_w_all,
#' nlp_w = nlp_w_all,
#' tau = c(0.5, -0.5, 0.5),
#' delta = c(0.5, 0.5),
#' psi = 2
#' )
#' c_grid <- c(
#' "runif(-0.6, 0)", # c(1,2)
#' "runif(0, 0.6)", # c(2,1)
#' "runif(-0.6, 0)", # c(2,3)
#' "seq(-0.6, 0, by = 0.3)", # c(1,3)
#' "seq(0, 0.6, by = 0.3)", # c(3,1)
#' "runif(0, 0.6)" # c(3,2)
#' )
#' sensitivity_analysis_parallel_result <-
#' sa(
#' m1 = 1,
#' x = data$covariates,
#' y = data$y,
#' w = data$w,
#' prior_c_function = c_grid,
#' n_cores = 1,
#' estimand = "ATE"
#' )
#' }
sa <- function(x,
               y,
               w,
               formula = NULL,
               prior_c_function,
               m1,
               m2 = NULL,
               n_cores = 1,
               estimand,
               reference_trt,
               ...) {
  # Check the user's inputs before any model is fitted.
  if (!(estimand %in% c("ATE", "ATT"))) {
    stop("Estimand only supported for \"ATT\" or \"ATE\"", call. = FALSE)
  }
  if (estimand == "ATT" && !(reference_trt %in% unique(w))) {
    stop(
      paste0(
        "Please set the reference_trt from ",
        paste0(sort(unique(w)), collapse = ", "),
        "."
      ),
      call. = FALSE
    )
  }
  # isTRUE() also rejects a vector x, for which nrow() is NULL.
  lengths_match <- isTRUE(length(w) == length(y)) &&
    isTRUE(length(w) == nrow(x))
  if (!lengths_match) {
    stop(
      paste0(
        "The length of y, the length of w and the nrow for x ",
        "should be equal.\nPlease double check the input."
      ),
      call. = FALSE
    )
  }
  # The outcome adjustment and the predictions below index treatments by
  # position, so w must be coded exactly 1, 2, 3.
  w <- as.integer(w)
  if (anyNA(w) || !identical(sort(unique(w)), 1:3)) {
    stop(
      "The treatment w must contain exactly three groups coded 1, 2 and 3.",
      call. = FALSE
    )
  }
  if (!is.null(formula)) {
    x <- as.data.frame(stats::model.matrix(object = formula, cbind(y, x)))
    # drop = FALSE keeps a one-covariate design as a data frame.
    x <- x[, names(x) != "(Intercept)", drop = FALSE]
  }

  # Confounding-function values: one row per set of values to analyse.
  prior <- .sa_prior_matrix(prior_c_function, m2, envir = environment())
  prior_c_function_used <- as.matrix(prior$draws)
  c_index_with_grid <- prior$grid_index
  is_grid <- length(c_index_with_grid) > 0
  if (ncol(prior_c_function_used) != 6) {
    stop(
      paste0(
        "prior_c_function must specify exactly 6 confounding functions: ",
        "c(1,2), c(2,1), c(2,3), c(1,3), c(3,1), c(3,2)."
      ),
      call. = FALSE
    )
  }

  # Change the types of x, y and w as the input parameters of BART.
  x <- as.matrix(x)
  y <- as.numeric(y)
  n_trt <- length(unique(w))
  n_alpha <- nrow(prior_c_function_used)
  n_obs <- length(w)

  # Treatment assignment model. For gap-sampling we draw m1 * 10 posterior
  # samples and keep every 10th one.
  a_model <- BART::mbart2(
    x.train = x,
    y.train = as.integer(as.factor(w)),
    x.test = x,
    ndpost = m1 * 10,
    mc.cores = n_cores
  )
  # GPS array with dimensions (m1, n_trt, n_obs); gps[j, k, ] is used below
  # as the j-th kept draw of P(w = k | x) for every unit.
  # NOTE(review): relies on mbart2's prob.test column layout (treatment
  # varying fastest within each unit) -- confirm against the BART docs.
  gps <- array(
    a_model$prob.test[seq(1, nrow(a_model$prob.test), 10), ],
    dim = c(m1, n_trt, n_obs)
  )

  # Each contrast is a (first, second) treatment pair; its draws are the row
  # means of prediction(first) - prediction(second) over x_predict.
  if (estimand == "ATE") {
    # All pairs k < m: (1, 2), (1, 3), (2, 3), predicted on every unit.
    contrast_first <- rep(seq_len(n_trt - 1), times = rev(seq_len(n_trt - 1)))
    contrast_second <- unlist(lapply(
      seq_len(n_trt - 1),
      function(k) (k + 1):n_trt
    ))
    x_predict <- x
  } else {
    # Reference group against every other group, predicted on the units
    # that received the reference treatment.
    reference_trt <- as.integer(reference_trt)
    contrast_first <- rep(reference_trt, n_trt - 1)
    contrast_second <- setdiff(seq_len(n_trt), reference_trt)
    x_predict <- x[w == reference_trt, , drop = FALSE]
  }
  draw_names <- paste0(estimand, "_", contrast_first, contrast_second)

  cl <- parallel::makeCluster(n_cores)
  # Release the workers even if a fit fails, and leave a usable sequential
  # foreach backend registered instead of the stopped cluster.
  on.exit(
    {
      parallel::stopCluster(cl)
      foreach::registerDoSEQ()
    },
    add = TRUE
  )
  doParallel::registerDoParallel(cl)

  out <- foreach::foreach(
    i = seq_len(n_alpha),
    # Stack per-job draws so each contrast becomes a matrix with one row per
    # set of confounding-function values.
    .combine = function(acc, new) {
      combined <- lapply(
        draw_names,
        function(nm) rbind(acc[[nm]], new[[nm]])
      )
      names(combined) <- draw_names
      combined
    }
  ) %dopar% {
    cat("Starting ", i, "th job.\n", sep = "")
    c_i <- prior_c_function_used[i, ]
    contrast_draws <- vector("list", length(draw_names))
    names(contrast_draws) <- draw_names
    for (j in seq_len(m1)) {
      # Adjust the binary outcome: a unit in group a subtracts
      # c(a, b) * GPS_b for each other group b. Column order of
      # prior_c_function_used: c12, c21, c23, c13, c31, c32.
      train_y <- y - ifelse(
        w == 1,
        c_i[1] * gps[j, 2, ] + c_i[4] * gps[j, 3, ],
        ifelse(
          w == 2,
          c_i[2] * gps[j, 1, ] + c_i[3] * gps[j, 3, ],
          c_i[5] * gps[j, 1, ] + c_i[6] * gps[j, 2, ]
        )
      )
      # Fit the outcome model on the adjusted outcome.
      bart_mod <- BART::wbart(
        x.train = cbind(x, w),
        y.train = train_y,
        printevery = 10000,
        ...
      )
      predictions <- vector("list", n_trt)
      for (k in seq_len(n_trt)) {
        predictions[[k]] <- BART::pwbart(
          cbind(x_predict, w = k),
          bart_mod$treedraws
        )
      }
      for (p in seq_along(draw_names)) {
        difference <- predictions[[contrast_first[p]]] -
          predictions[[contrast_second[p]]]
        contrast_draws[[p]] <- c(contrast_draws[[p]], rowMeans(difference))
      }
    }
    contrast_draws
  }

  if (is_grid) {
    # Average over posterior draws: one value per grid row. rbind() keeps a
    # single-row grid (a plain vector) usable by apply().
    result_final <- lapply(draw_names, function(nm) {
      apply(rbind(out[[nm]], deparse.level = 0), 1, mean)
    })
    names(result_final) <- paste0(estimand, contrast_first, contrast_second)
    result_final <- c(
      result_final,
      list(
        c_functions = prior_c_function_used,
        grid_index = c_index_with_grid
      )
    )
    class(result_final) <- "CIMTx_sa_grid"
  } else {
    result_final <- out[draw_names]
    names(result_final) <- paste0(
      estimand, "_RD", contrast_first, contrast_second
    )
    class(result_final) <- paste0("CIMTx_", estimand, "_sa")
  }
  result_final
}

# Build the confounding-function values used by sa().
#
# Returns a list with
#   draws      - a numeric matrix with one column per confounding function
#                and one row per set of values to analyse;
#   grid_index - positions of the grid ("seq") elements of prior_c_function
#                (integer(0) when there are none, NULL for a numeric matrix).
# Character specifications are evaluated with eval(parse()) in `envir`; they
# are trusted user input by design.
.sa_prior_matrix <- function(prior_c_function, m2, envir) {
  # Point-mass prior: a single row holding the column means.
  if (is.numeric(prior_c_function)) {
    return(list(
      draws = t(apply(prior_c_function, 2, mean)),
      grid_index = NULL
    ))
  }
  n_c <- length(prior_c_function)
  grid_index <- which(grepl("seq", prior_c_function, fixed = TRUE))
  if (length(grid_index) == 0) {
    if (is.null(m2)) {
      stop(
        paste0(
          "Please specify m2, the number of draws from the prior ",
          "distributions of the confounding functions."
        ),
        call. = FALSE
      )
    }
    draws <- matrix(NA_real_, nrow = m2, ncol = n_c)
  } else {
    # Every combination of the grid values; the first grid varies slowest.
    grids <- lapply(prior_c_function[grid_index], function(spec) {
      eval(parse(text = spec), envir = envir)
    })
    names(grids) <- paste0("c", grid_index)
    grid_values <- do.call(tidyr::expand_grid, grids)
    m2 <- nrow(grid_values)
    draws <- matrix(
      NA_real_,
      nrow = m2,
      ncol = n_c,
      dimnames = list(NULL, paste0("c", seq_len(n_c)))
    )
    for (idx in grid_index) {
      draws[, idx] <- grid_values[[paste0("c", idx)]]
    }
  }
  # The remaining elements are random-draw specifications, drawn in order.
  for (idx in setdiff(seq_len(n_c), grid_index)) {
    draws[, idx] <- .sa_draw_prior(prior_c_function[idx], m2, envir)
  }
  list(draws = draws, grid_index = grid_index)
}

# Draw n values from a prior written as a stats RNG call without its sample
# size, e.g. "runif(-0.6, 0)" is evaluated as runif(n, -0.6, 0).
.sa_draw_prior <- function(spec, n, envir) {
  open_paren <- stringr::str_locate(spec, "\\(")[1]
  call_text <- paste0(
    stringr::str_sub(spec, 1, open_paren),
    n,
    ",",
    stringr::str_sub(spec, open_paren + 1)
  )
  eval(parse(text = call_text), envir = envir)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.