R/bart_multiTrt_ate.R

Defines functions bart_multiTrt_ate

#######################################################################
# Functions to estimate causal effects of multiple treatments using BART
# Assumes 3 treatment groups, labelled 1, 2 and 3
#######################################################################


# Estimate average treatment effect (ATE)
# Fit a probit BART outcome model on cbind(trt, x) and estimate the pairwise
# average treatment effects among treatment groups labelled 1, 2 and 3.
#
# Arguments:
#   y        binary outcome, passed to BART::pbart() as y.train.
#   x        covariates (one row per unit), bound after `trt` as columns.
#   trt      treatment label per unit; groups 1, 2 and 3 are predicted.
#   k, ntree, ndpost, nskip
#            passed through unchanged to BART::pbart().
#   discard  if TRUE, apply the BART common-support discarding rule: a unit
#            observed in group g is dropped when the posterior SD of either
#            of its counterfactual predicted probabilities exceeds the largest
#            posterior SD of the factual predictions within group g.
#
# Returns a list with elements ATE12, ATE13 and ATE23. Each is the result of
# postSumm() (defined elsewhere) applied to the posterior draws of the risk
# difference (RD), relative risk (RR) and odds ratio (OR), rounded to 3
# digits.
bart_multiTrt_ate <- function(y, x, trt, k = 2, discard = FALSE, ntree = 100,
                              ndpost = 1000, nskip = 1000) {

  xt <- cbind(trt, x)

  # Fit BART
  bart_mod <- BART::pbart(x.train = xt, y.train = y, k = k, ntree = ntree,
                          ndpost = ndpost, nskip = nskip)

  # pwbart() only adds `mu` to the tree draws. pbart() fits the trees after
  # subtracting a probit offset, so that offset must be added back (as
  # predict.pbart() does). Depending on the BART version it is stored as
  # `binaryOffset` or `offset`. pbart's own default is qnorm(mean(y)).
  offset <- bart_mod$binaryOffset
  if (is.null(offset)) offset <- bart_mod$offset
  if (is.null(offset)) offset <- qnorm(mean(y))

  # Posterior draws (rows) of P(Y = 1) for the units observed in group
  # `from` (columns), with their treatment column switched to `to`.
  predict_prob <- function(from, to) {
    # drop = FALSE keeps a single-unit group as a one-row matrix
    xp <- xt[trt == from, , drop = FALSE]
    xp[, 1] <- to
    pnorm(BART::pwbart(xp, bart_mod$treedraws, mu = offset))
  }

  # Predict potential outcomes for units observed under trt = 1, 2, 3
  pred_prop11 <- predict_prob(1, 1)
  pred_prop12 <- predict_prob(1, 2)
  pred_prop13 <- predict_prob(1, 3)

  pred_prop21 <- predict_prob(2, 1)
  pred_prop22 <- predict_prob(2, 2)
  pred_prop23 <- predict_prob(2, 3)

  pred_prop31 <- predict_prob(3, 1)
  pred_prop32 <- predict_prob(3, 2)
  pred_prop33 <- predict_prob(3, 3)

  if (isTRUE(as.logical(discard))) {
    #****************************#
    #  BART discarding rule      #
    #****************************#
    # Posterior SD of each unit's predicted probability (one per column).
    post_ind_sd <- function(p) apply(p, 2, sd)

    # A unit is within common support when both of its counterfactual SDs
    # are no larger than the maximum factual SD in its own group.
    within_support <- function(factual, cf_a, cf_b) {
      threshold <- max(post_ind_sd(factual))
      post_ind_sd(cf_a) <= threshold & post_ind_sd(cf_b) <= threshold
    }

    eligible1 <- within_support(pred_prop11, pred_prop12, pred_prop13)
    eligible2 <- within_support(pred_prop22, pred_prop21, pred_prop23)
    eligible3 <- within_support(pred_prop33, pred_prop31, pred_prop32)

    # Keep only the units within the common support region. drop = FALSE
    # keeps a matrix even when a single unit remains.
    pred_prop11 <- pred_prop11[, eligible1, drop = FALSE]
    pred_prop12 <- pred_prop12[, eligible1, drop = FALSE]
    pred_prop13 <- pred_prop13[, eligible1, drop = FALSE]

    pred_prop21 <- pred_prop21[, eligible2, drop = FALSE]
    pred_prop22 <- pred_prop22[, eligible2, drop = FALSE]
    pred_prop23 <- pred_prop23[, eligible2, drop = FALSE]

    pred_prop31 <- pred_prop31[, eligible3, drop = FALSE]
    pred_prop32 <- pred_prop32[, eligible3, drop = FALSE]
    pred_prop33 <- pred_prop33[, eligible3, drop = FALSE]
  }

  # Mean of one Bernoulli draw per retained unit. The draw count is taken
  # from the probabilities themselves, so it always matches the units kept
  # after discarding. A fixed group size would make rbinom() recycle `p`.
  draw_mean <- function(...) {
    p <- c(...)
    mean(rbinom(length(p), 1, p))
  }
  odds <- function(p) p / (1 - p)

  # Estimate causal effects
  n_draws <- nrow(pred_prop11)
  RD12_est <- RR12_est <- OR12_est <- numeric(n_draws)
  RD13_est <- RR13_est <- OR13_est <- numeric(n_draws)
  RD23_est <- RR23_est <- OR23_est <- numeric(n_draws)

  for (m in seq_len(n_draws)) {
    # Estimate E(Y1), E(Y2), E(Y3) over all retained units
    y1_pred <- draw_mean(pred_prop11[m, ], pred_prop21[m, ], pred_prop31[m, ])
    y2_pred <- draw_mean(pred_prop12[m, ], pred_prop22[m, ], pred_prop32[m, ])
    y3_pred <- draw_mean(pred_prop13[m, ], pred_prop23[m, ], pred_prop33[m, ])

    # Risk difference (RD)
    RD12_est[m] <- y1_pred - y2_pred
    RD13_est[m] <- y1_pred - y3_pred
    RD23_est[m] <- y2_pred - y3_pred

    # Relative risk (RR)
    RR12_est[m] <- y1_pred / y2_pred
    RR13_est[m] <- y1_pred / y3_pred
    RR23_est[m] <- y2_pred / y3_pred

    # Odds ratio (OR)
    OR12_est[m] <- odds(y1_pred) / odds(y2_pred)
    OR13_est[m] <- odds(y1_pred) / odds(y3_pred)
    OR23_est[m] <- odds(y2_pred) / odds(y3_pred)
  }

  ate12 <- postSumm(RD12_est, RR12_est, OR12_est)
  ate13 <- postSumm(RD13_est, RR13_est, OR13_est)
  ate23 <- postSumm(RD23_est, RR23_est, OR23_est)

  list(ATE12 = round(ate12, digits = 3),
       ATE13 = round(ate13, digits = 3),
       ATE23 = round(ate23, digits = 3))
}
JiayiJi/causal.multiple.treatments documentation built on Nov. 14, 2019, 7:46 p.m.