```r
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```
This vignette simulates data using the scheme described by @Hill2013 with the addition of one categorical variable. It is implemented in the function `simulate_su_hill_data()`:
```r
# before running library(bartMachine), set memory for the Java VM
# (restart R for this to take effect)
options(java.parameters = "-Xmx2000m")

# check memory allocated to Java VM
options("java.parameters")

suppressPackageStartupMessages({
  library(bartMachine)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

sim <- simulate_su_hill_data(
  n = 200, treatment_linear = FALSE, omega = 0, add_categorical = TRUE,
  coef_categorical_treatment = c(0, 0, 1),
  coef_categorical_nontreatment = c(-1, 0, -1)
)

# non-treated vs treated counts:
table(sim$data$z)

dat <- sim$data
dat$c1 <- as.integer(dat$c1)

# a selection of the data
dat %>% select(y, z, c1, x1:x3) %>% head()
```
Run the model to be used to assess treatment effects. Here we will use `bartMachine`, which is one implementation of Bayesian Additive Regression Trees in R [@Kapelner2016]. The package can be found on CRAN.
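If it is not already installed, `bartMachine` can be obtained from CRAN. Note that it depends on `rJava`, which in turn requires a working Java installation on your system.

```r
# install bartMachine from CRAN; its rJava dependency requires Java to be installed
install.packages("bartMachine")
```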
```r
# if you increase the number of cores, the memory needs to be increased;
# this requires restarting R, setting the 'java.parameters' option, then
# loading the bartMachine package.
set_bart_machine_num_cores(2)

# set serialize = TRUE if using the fit over multiple sessions

# The first two BART models build the propensity score
# (i.e. the propensity for selection into treatment):
#   1. regress y ~ covariates (excluding z) and select important variables
#   2. regress z ~ selected covariates to estimate the propensity score

# regress y ~ covariates
var_select_bart <- bartMachine(
  X = select(dat, -y, -z),
  y = select(dat, y)[[1]],
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)

# select most important vars from the y ~ covariates model
var_select <- bartMachine::var_selection_by_permute_cv(var_select_bart, k_folds = 5)

# regress z ~ most important covariates to get the propensity score
prop_bart <- bartMachine(
  X = select(dat, var_select$important_vars_cv),
  y = as.factor(select(dat, z)[[1]]),
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)

dat$prop_score <- prop_bart$p_hat_train

destroy_bart_machine(var_select_bart)
destroy_bart_machine(prop_bart)

# give the treatment variable, z, double the prior inclusion probability
prior_incl_prob <- setNames(
  rep(1, times = ncol(dat) - 1),
  colnames(dat)[colnames(dat) != "y"]
)
prior_incl_prob["z"] <- 2

# fit the outcome model: y ~ covariates + z + propensity score
bartM <- bartMachine(
  X = select(dat, -y),
  y = select(dat, y)[[1]],
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE,
  cov_prior_vec = prior_incl_prob
)
```
Here are some examples of model checking we can do.
```r
print(bartM)

res <- residual_draws(bartM, include_newdata = FALSE)

res %>%
  point_interval(.residual, y, .width = c(0.95)) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() +
  geom_pointinterval(aes(x = y, y = .residual,
                         ymin = .residual.lower, ymax = .residual.upper),
                     alpha = 0.2) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Residuals vs observations")

res %>%
  summarise(.fitted = mean(.fitted), y = first(y)) %>%
  ggplot(aes(x = y, y = .fitted)) +
  geom_point() +
  geom_smooth(method = "lm") +
  theme_bw() +
  ggtitle("Observations vs fitted")

res %>%
  summarise(.residual = mean(.residual)) %>%
  ggplot(aes(sample = .residual)) +
  geom_qq() +
  geom_qq_line() +
  theme_bw() +
  ggtitle("Q-Q plot of residuals")

bartMachine:::plot_sigsqs_convergence_diagnostics(bartM)
bartMachine:::plot_mh_acceptance_reject(bartM)
bartMachine:::plot_tree_num_nodes(bartM)
bartMachine:::plot_tree_depths(bartM)
```
Methods for extracting the posterior in a tidy format are included in the `tidytreatment` package.
```r
posterior_fitted <- fitted_draws(bartM, value = "fit", include_newdata = FALSE)
# The newdata argument (omitted here) defaults to the data used to fit the model.
# include_newdata = FALSE avoids returning the newdata alongside the fitted
# values, as it is large.
# The `.row` variable records which row of the newdata each fitted value came
# from (useful if we don't include the data in the result).

posterior_fitted
```
```r
# a function to tidy predicted draws is also included
posterior_pred <- predicted_draws(bartM, include_newdata = FALSE)
```
Since `tidytreatment` follows the `tidybayes` output specifications, functions from the `tidybayes` package should work.
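For example, a `tidybayes`-style point-interval summary can be computed directly from the fitted draws. This is a minimal sketch; the summary column names (such as `.lower` and `.upper`) follow the `tidybayes`/`ggdist` conventions.

```r
# summarise the posterior fitted value for each observation (.row) with a
# median and 95% quantile interval using tidybayes
posterior_fitted %>%
  group_by(.row) %>%
  median_qi(fit, .width = 0.95) %>%
  head()
```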
```r
treatment_var_and_c1 <-
  dat %>%
  select(z, c1) %>%
  mutate(.row = 1:n(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  geom_eye(aes(x = z, y = fit)) +
  facet_wrap(~ c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")
```
Posterior conditional (average) treatment effects can be calculated using the `treatment_effects()` function. This function finds the posterior values of
$$
\text{E}(y ~ \vert~ T = 1, X = x_{i}) - \text{E}(y ~ \vert~ T = 0, X = x_{i})
$$
for each unit, $i$ (e.g. each subject), in the data sample.
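Conceptually, this amounts to predicting each subject's outcome under both treatment settings and differencing the posterior draws. The sketch below illustrates the idea directly with `bartMachine` posterior samples; it is for intuition only, not the internal implementation of `treatment_effects()`, and the counterfactual data handling here is an assumption.

```r
# Illustration only: posterior draws of E(y | z = 1, x_i) - E(y | z = 0, x_i)
# obtained by predicting on counterfactual copies of the model's design matrix.
X_model   <- select(dat, -y)          # covariates used to fit bartM (includes z)
X_treated <- mutate(X_model, z = 1)   # everyone treated
X_control <- mutate(X_model, z = 0)   # everyone untreated

# bart_machine_get_posterior() returns posterior samples of E(y | x)
# as a matrix with one row per observation and one column per draw
draws_treated <- bart_machine_get_posterior(bartM, X_treated)$y_hat_posterior_samples
draws_control <- bart_machine_get_posterior(bartM, X_control)$y_hat_posterior_samples

cte_draws <- draws_treated - draws_control  # subjects x posterior draws
dim(cte_draws)
```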
Some histogram summaries are presented below.
```r
# sample based (using data from fit) conditional treatment effects, posterior draws
posterior_treat_eff <- treatment_effects(bartM, treatment = "z")

posterior_treat_eff %>%
  ggplot() +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

posterior_treat_eff %>%
  summarise(cte_hat = median(cte)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")
```
We can also focus on the treatment effects for just those that are treated.
```r
# sample based (using data from fit) conditional treatment effects, posterior draws
posterior_treat_eff_on_treated <-
  treatment_effects(bartM, treatment = "z", subset = "treated")

posterior_treat_eff_on_treated %>%
  ggplot() +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws from treated subjects)")
```
Plots can be made that stack each subject's posterior credible intervals (CIs) of the CATEs.
```r
posterior_treat_eff %>%
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  mutate(.orow = 1:n()) %>%
  ggplot() +
  geom_interval(aes(x = .orow, y = cte), size = 0.5) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.1) +
  ylab("Median posterior CATE for each subject (95% CI)") +
  theme_bw() +
  coord_flip() +
  scale_colour_brewer() +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        legend.position = "none")
```
We can also plot the CATEs varying over particular covariates. In this example, instead of grouping by subject, we group by the variable of interest, and calculate the posterior summaries over this variable.
```r
posterior_treat_eff %>%
  left_join(dplyr::tibble(c1 = dat$c1, .row = 1:length(dat$c1)), by = ".row") %>%
  group_by(c1) %>%
  ggplot() +
  geom_eye(aes(x = c1, y = cte), alpha = 0.2) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Treatment effect by `c1`")
```
Common support [@Hill2013] can be tested directly, or a common support rule can be applied when calculating the treatment effects.
```r
csupp1 <- has_common_support(bartM, treatment = "z", method = "chisq", cutoff = 0.05)
csupp1 %>% filter(!common_support)

csupp2 <- has_common_support(bartM, treatment = "z", method = "sd", cutoff = 1)
csupp2 %>% filter(!common_support)

posterior_treat_eff_on_treated <-
  treatment_effects(bartM, treatment = "z", subset = "treated",
                    common_support_method = "sd", cutoff = 1)
```
We can count how many times each variable was included in the BART model's trees, either in conjunction with the treatment variable or overall.
```r
treatment_interactions <- covariate_with_treatment_importance(bartM, treatment = "z")

treatment_interactions %>%
  ggplot() +
  geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
  theme_bw() +
  ggtitle("Important variables interacting with treatment ('z')") +
  ylab("Inclusion counts")

variable_importance <- covariate_importance(bartM)

variable_importance %>%
  ggplot() +
  geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
  theme_bw() +
  ggtitle("Important variables overall") +
  ylab("Inclusion counts")
```