# Nothing
## ----setup, include = FALSE---------------------------------------------------
# Global knitr chunk options: collapse source and output into one block,
# prefix output lines with "#>", and set the figure dimensions to c(6, 4).
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.dim = c(6, 4)
)

# Attach the packages while suppressing their start-up messages.
suppressPackageStartupMessages({
  library(BART)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

# load pre-computed data and model
# NOTE(review): `suhillsim1` and `bartmodel1` are not defined in this file;
# presumably they are datasets shipped with one of the attached packages.
sim <- suhillsim1
te_model <- bartmodel1

# pre compute
posterior_treat_eff <- treatment_effects(
  te_model,
  treatment = "z",
  newdata = sim$data
)

# Spell out `sim$data` in full: the abbreviated `sim$dat` only resolves through
# partial matching of `$`, which is fragile and inconsistent with the call above.
posterior_treat_eff_on_treated <- treatment_effects(
  te_model,
  treatment = "z",
  newdata = sim$data,
  subset = "treated"
)
## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
# NOTE: this chunk is not evaluated (eval = FALSE) and is kept here as comments.
# The live script assigns `sim` from the pre-computed `suhillsim1` object instead.
#
# # load packages
# library(BART)
# library(tidytreatment)
# library(dplyr)
# library(tidybayes)
# library(ggplot2)
#
# # set seed so vignette is reproducible
# set.seed(101)
#
# # simulate data
# sim <- simulate_su_hill_data(
#   n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE,
#   coef_categorical_treatment = c(0,0,1),
#   coef_categorical_nontreatment = c(-1,0,-1)
# )
#
## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------
# Working copy of the simulated data; later chunks refer to it as `dat`.
dat <- sim$data

# Counts of non-treated vs treated observations (levels of `z`).
table(dat$z)

# Peek at the first rows of the response, treatment and a few covariates.
head(select(dat, y, z, c1, x1:x3))
## ----run-bart, echo = TRUE, eval = FALSE--------------------------------------
# NOTE: this chunk is not evaluated (eval = FALSE) and is kept here as comments.
# The live script takes `te_model` from the pre-computed `bartmodel1` object.
#
# # STEP 1 VS Model: Regress y ~ covariates
# var_select_bart <- wbart(x.train = select(dat,-y,-z),
#                          y.train = pull(dat, y),
#                          sparse = TRUE,
#                          nskip = 2000,
#                          ndpost = 5000)
#
# # STEP 2: Variable selection
# # select most important vars from y ~ covariates model
# # very simple selection mechanism. Should use cross-validation in practice
# covar_ranking <- covariate_importance(var_select_bart)
# var_select <- covar_ranking %>%
#   filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>%
#   pull(variable)
#
# # change categorical variables to just one variable
# var_select <- unique(gsub("c1[1-3]$","c1", var_select))
#
# var_select
#
# # STEP 3 PS Model: Regress z ~ selected covariates
# # BART::pbart is for probit regression
# prop_bart <- pbart(
#   x.train = select(dat, all_of(var_select)),
#   y.train = pull(dat, z),
#   nskip = 2000,
#   ndpost = 5000
# )
#
# # store propensity score in data
# dat$prop_score <- prop_bart$prob.train.mean
#
# # STEP 4 TE Model: Regress y ~ z + covariates + propensity score
# te_model <- wbart(
#   x.train = select(dat,-y),
#   y.train = pull(dat, y),
#   nskip = 10000L,
#   ndpost = 200L, #*
#   keepevery = 100L #*
# )
#
# #* The posterior samples are kept small to manage size on CRAN
#
## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------
# Tidy the posterior fitted values of the model, storing them in a column
# named `fit`.
#
# include_newdata = FALSE avoids returning the newdata alongside the fitted
# values, as it is so large. The `.row` variable records which row of the
# newdata each fitted value came from (needed when the data is not included
# in the result).
# NOTE(review): the original note also states that the newdata argument "must
# be specified for this option in BART models", yet no newdata is passed here;
# confirm which setting that remark applies to.
posterior_fitted <- fitted_draws(
  te_model,
  value = "fit",
  include_newdata = FALSE
)

posterior_fitted
## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
#
# # Function to tidy predicted draws also, this adds random normal noise by default
# posterior_pred <- predicted_draws(te_model, include_newdata = FALSE)
#
## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------
# Lookup table holding the treatment indicator `z` (as a factor) and `c1` for
# every row of `dat`, keyed by `.row` so it can be joined onto the draws.
treatment_var_and_c1 <-
  dat %>%
  select(z, c1) %>%
  # row_number() stays correct for zero-row input, where 1:n() yields c(1, 0)
  mutate(.row = row_number(), z = as.factor(z))

# Posterior fitted values by treatment group, one panel per value of `c1`.
posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fit)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")
## ----post-treatment, eval = FALSE---------------------------------------------
#
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff <-
#   treatment_effects(te_model, treatment = "z", newdata = dat)
#
## ----cates-hist, echo=TRUE, cache=FALSE---------------------------------------
# Distribution of the treatment effect `cte` across all posterior draws.
ggplot(posterior_treat_eff) +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

# Distribution of the median treatment effect.
# NOTE(review): summarise() relies on the grouping already carried by
# `posterior_treat_eff` (presumably one group per subject) -- confirm.
posterior_treat_eff %>%
  summarise(cte_hat = median(cte)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")
## ----att-ate, eval=FALSE------------------------------------------------------
# # get the ATE and ATT directly:
#
# posterior_ate <- tidy_ate(te_model, treatment = "z", newdata = dat)
# posterior_att <- tidy_att(te_model, treatment = "z", newdata = dat)
#
## ----ate-trace-setup, eval = TRUE, echo = FALSE-------------------------------
# Average the treatment effect `cte` within each (.chain, .iteration, .draw)
# combination, returning an ungrouped table with one `ate` value per draw.
posterior_ate <- summarise(
  group_by(posterior_treat_eff, .chain, .iteration, .draw),
  ate = mean(cte),
  .groups = "drop"
)
## ----ate-trace, eval=TRUE, echo=TRUE------------------------------------------
# Line plot of `ate` against the draw index.
ggplot(posterior_ate, aes(x = .draw, y = ate)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of ATE")
## ----post-te-treated, echo=TRUE, eval=FALSE-----------------------------------
#
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff_on_treated <-
#   treatment_effects(te_model, treatment = "z", newdata = dat, subset = "treated")
#
## ----cates-hist-treated, echo=TRUE, cache=FALSE-------------------------------
# Distribution of `cte` over all draws computed with subset = "treated".
ggplot(posterior_treat_eff_on_treated) +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws from treated subjects)")
## ----cates-stack-plot, echo=TRUE, cache=FALSE---------------------------------
# Point-and-interval summary of `cte`, sorted by the point estimate and stacked
# along an ordering index `.orow`. coord_flip() puts the effect on the
# horizontal axis; the ordering axis carries no information, so its title,
# text and ticks are blanked.
posterior_treat_eff %>%
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  # row_number() stays correct for zero-row input, where 1:n() yields c(1, 0)
  mutate(.orow = row_number()) %>%
  ggplot() +
  geom_interval(aes(x = .orow, y = cte, ymin = .lower, ymax = .upper)) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) +
  ylab("Median posterior CATE for each subject (95% CI)") +
  theme_bw() +
  coord_flip() +
  scale_colour_brewer() +
  theme(
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "none"
  )
## ----cates-line-plot, echo=TRUE, cache=FALSE----------------------------------
# Attach `c1` to each draw via the `.row` key, then show the distribution of
# the treatment effect `cte` for each value of `c1`.
posterior_treat_eff %>%
  left_join(
    # seq_along() is the safe form of 1:length(x); it is empty for empty input
    tibble(c1 = dat$c1, .row = seq_along(dat$c1)),
    by = ".row"
  ) %>%
  group_by(c1) %>%
  ggplot() +
  stat_halfeye(aes(x = c1, y = cte), alpha = 0.7) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Treatment effect by `c1`")
## ----common-support, echo=TRUE, results='hide', cache=FALSE-------------------
# Compute common support directly with two different rules.
# The 'modeldata' argument must be specified for BART models.

# Rule 1: method "chisq" with cutoff 0.05; list rows flagged as lacking support.
csupp_chisq <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "chisq",
  cutoff = 0.05
)
filter(csupp_chisq, !common_support)

# Rule 2: method "sd" with cutoff 1; list rows flagged as lacking support.
csupp_sd <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "sd",
  cutoff = 1
)
filter(csupp_sd, !common_support)

# Recompute the treatment effects on those who were treated, keeping only the
# estimates with common support. This overwrites the earlier object of the
# same name.
posterior_treat_eff_on_treated <- treatment_effects(
  te_model,
  treatment = "z",
  subset = "treated",
  newdata = dat,
  common_support_method = "sd",
  cutoff = 1
)
## ----interaction-investigator, echo=TRUE, cache=FALSE-------------------------
# Bar chart of `avg_inclusion` per variable for covariates interacting with the
# treatment `z`. geom_col() draws bars whose heights are the supplied y values.
treatment_interactions <-
  covariate_with_treatment_importance(te_model, treatment = "z")

ggplot(treatment_interactions) +
  geom_col(aes(x = variable, y = avg_inclusion)) +
  theme_bw() +
  ggtitle("Important variables interacting with treatment ('z')") +
  ylab("Inclusion counts") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# The same chart for overall covariate importance in the model.
variable_importance <- covariate_importance(te_model)

ggplot(variable_importance) +
  geom_col(aes(x = variable, y = avg_inclusion)) +
  theme_bw() +
  ggtitle("Important variables overall") +
  ylab("Inclusion counts") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
## ----sigma-trace, echo=TRUE, cache=FALSE--------------------------------------
# Trace of the model variance draws. The draws include skipped MCMC samples, so
# only draws numbered above 10000 are kept (presumably matching the
# nskip = 10000L warm-up used when the model was fitted -- confirm).
# The output column is named `sigsq`; it was previously misspelt "siqsq",
# which also appeared as the y-axis label.
variance_draws(te_model, value = "sigsq") %>%
  filter(.draw > 10000) %>%
  ggplot(aes(x = .draw, y = sigsq)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of model variance post warm-up")
## ----convergence-bart, echo=TRUE, cache=FALSE---------------------------------
# Residual draws of the model against the observed response `y`.
res <- residual_draws(
  te_model,
  response = pull(dat, y),
  include_newdata = FALSE
)

# Residuals vs observations: 95% point intervals of the residuals, with the
# interval columns of `y` dropped since only its point value is plotted.
res %>%
  point_interval(.residual, y, .width = 0.95) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() +
  geom_pointinterval(
    aes(
      x = y,
      y = .residual,
      ymin = .residual.lower,
      ymax = .residual.upper
    ),
    alpha = 0.2
  ) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Residuals vs observations")

# Observations vs mean fitted value, with a linear smoother overlaid.
# NOTE(review): summarise() relies on the grouping carried by `res`
# (presumably one group per observation) -- confirm.
res %>%
  summarise(.fitted = mean(.fitted), y = first(y)) %>%
  ggplot(aes(x = y, y = .fitted)) +
  geom_point() +
  geom_smooth(method = "lm") +
  theme_bw() +
  ggtitle("Observations vs fitted")

# Q-Q plot of the mean residuals.
res %>%
  summarise(.residual = mean(.residual)) %>%
  ggplot(aes(sample = .residual)) +
  geom_qq() +
  geom_qq_line() +
  theme_bw() +
  ggtitle("Q-Q plot of residuals")
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.