# Estimate the average treatment effect on the treated (ATT) for a binary
# outcome with three treatment groups, using probit BART (BART::pbart).
# The reference group is treatment 1: the effects compare treatment 1 with
# treatment 2 (ATT12) and with treatment 3 (ATT13) among units with trt == 1.
#
# Arguments:
#   y        binary outcome, passed to BART::pbart as y.train.
#   x        covariates; bound column-wise after trt to form the design.
#   trt      treatment labels 1, 2, 3. Units with trt == 1 are the reference
#            group; their counterfactuals are predicted at trt = 2 and 3.
#   k, ntree, ndpost, nskip
#            passed through unchanged to BART::pbart.
#   discard  if TRUE, drop reference units whose counterfactual predictions
#            have a posterior SD above the largest posterior SD of the
#            factual (trt = 1) predictions (common-support discarding rule).
#
# Returns a list with ATT12 and ATT13: postSumm() applied to the posterior
# draws of the risk difference, relative risk and odds ratio, rounded to 3
# digits. When discard = TRUE the list also contains n_discard, the number
# of trt == 1 units that were discarded.
bart_multiTrt_att <- function(y, x, trt, k = 2, discard = FALSE, ntree = 100,
                              ndpost = 1000, nskip = 1000) {
  xt <- cbind(trt, x)
  # Fit BART
  bart_mod <- BART::pbart(x.train = xt, y.train = y, k = k, ntree = ntree,
                          ndpost = ndpost, nskip = nskip)
  # pbart fits the latent probit scale around binaryOffset; pwbart only adds
  # it back when passed as `mu` (this is what predict.pbart does).
  offset <- bart_mod$binaryOffset
  if (is.null(offset)) {
    offset <- 0
  }

  # Predicted probabilities for the trt == 1 units with their treatment
  # label switched to `label`: rows = posterior draws, columns = units.
  # drop = FALSE keeps a matrix even when only one unit has trt == 1.
  predict_prob <- function(label) {
    xp <- xt[trt == 1, , drop = FALSE]
    xp[, 1] <- label
    pnorm(BART::pwbart(xp, bart_mod$treedraws, mu = offset))
  }
  pred_prop11 <- predict_prob(1)
  pred_prop12 <- predict_prob(2) # switch treatment label 1 to 2
  pred_prop13 <- predict_prob(3) # switch treatment label 1 to 3

  n_discard_att <- 0L
  if (isTRUE(discard)) {
    #****************************#
    # discarding rule            #
    #****************************#
    # posterior SD of the predicted outcome among those treated with W = 1
    post_ind_sd1 <- apply(pred_prop11, 2, sd)
    # threshold <- max(post_ind_sd1) + sd(post_ind_sd1)
    threshold <- max(post_ind_sd1)
    # discard unit i with W_i = 1 if the posterior SD of its counterfactual
    # outcomes exceeds the threshold
    post_ind_sd2 <- apply(pred_prop12, 2, sd)
    post_ind_sd3 <- apply(pred_prop13, 2, sd)
    eligible <- post_ind_sd2 <= threshold & post_ind_sd3 <= threshold
    n_discard_att <- sum(!eligible)
    if (!any(eligible)) {
      stop("All units with trt == 1 were discarded; the ATT cannot be ",
           "estimated.", call. = FALSE)
    }
    # these units are within the common support region
    pred_prop11 <- pred_prop11[, eligible, drop = FALSE]
    pred_prop12 <- pred_prop12[, eligible, drop = FALSE]
    pred_prop13 <- pred_prop13[, eligible, drop = FALSE]
  }

  # For every posterior draw (row) draw one binary outcome per retained unit
  # from that draw's probabilities and average across units, giving a
  # posterior draw of E(Y_w | trt = 1). The number of units is taken from
  # the matrix itself so it stays correct after discarding.
  draw_mean_outcome <- function(prob) {
    ys <- matrix(rbinom(length(prob), 1, prob), nrow = nrow(prob))
    rowMeans(ys)
  }
  y1_pred <- draw_mean_outcome(pred_prop11)
  y2_pred <- draw_mean_outcome(pred_prop12)
  y3_pred <- draw_mean_outcome(pred_prop13)

  odds <- function(p) p / (1 - p)
  # Risk difference (RD), relative risk (RR) and odds ratio (OR) per draw
  RD12_est <- y1_pred - y2_pred
  RD13_est <- y1_pred - y3_pred
  RR12_est <- y1_pred / y2_pred
  RR13_est <- y1_pred / y3_pred
  OR12_est <- odds(y1_pred) / odds(y2_pred)
  OR13_est <- odds(y1_pred) / odds(y3_pred)

  att12 <- postSumm(RD12_est, RR12_est, OR12_est)
  att13 <- postSumm(RD13_est, RR13_est, OR13_est)
  res <- list(ATT12 = round(att12, digits = 3),
              ATT13 = round(att13, digits = 3))
  if (isTRUE(discard)) {
    res$n_discard <- n_discard_att
  }
  res
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.