# inst/doc/use-tidytreatment-BART.R

## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.dim = c(6, 4)
)

# Load packages quietly so startup banners do not clutter the vignette.
suppressPackageStartupMessages({
  library(BART)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

# Load pre-computed data and model so the vignette does not have to refit
# the BART models when it is built.
sim <- suhillsim1
te_model <- bartmodel1

# Pre-compute the treatment-effect draws used throughout the vignette.
# Always spell out `sim$data` in full: `$` on a list partially matches names,
# so an abbreviation like `sim$dat` only works by accident (and warns under
# options(warnPartialMatchDollar = TRUE)).
posterior_treat_eff <- treatment_effects(
  te_model,
  treatment = "z",
  newdata = sim$data
)
posterior_treat_eff_on_treated <- treatment_effects(
  te_model,
  treatment = "z",
  newdata = sim$data,
  subset = "treated"
)
  

## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
#  
#  # load packages
#  library(BART)
#  library(tidytreatment)
#  library(dplyr)
#  library(tidybayes)
#  library(ggplot2)
#  
#  # set seed so vignette is reproducible
#  set.seed(101)
#  
#  # simulate data
#  sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE,  omega = 0, add_categorical = TRUE,
#                               coef_categorical_treatment = c(0,0,1),
#                               coef_categorical_nontreatment = c(-1,0,-1)
#  )
#  

## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------

dat <- sim$data

# Counts of non-treated (0) vs treated (1) subjects:
table(dat$z)

# Peek at the outcome, treatment and a few covariates.
head(select(dat, y, z, c1, x1:x3))


## ----run-bart, echo = TRUE, eval = FALSE--------------------------------------
#  
#  # STEP 1 VS Model: Regress y ~ covariates
#  var_select_bart <- wbart(x.train = select(dat,-y,-z),
#                           y.train = pull(dat, y),
#                           sparse = TRUE,
#                           nskip = 2000,
#                           ndpost = 5000)
#  
#  # STEP 2: Variable selection
#    # select most important vars from y ~ covariates model
#    # very simple selection mechanism. Should use cross-validation in practice
#  covar_ranking <- covariate_importance(var_select_bart)
#  var_select <- covar_ranking %>%
#    filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>%
#    pull(variable)
#  
#  # change categorical variables to just one variable
#  var_select <- unique(gsub("c1[1-3]$","c1", var_select))
#  
#  var_select
#  
#  # STEP 3 PS Model: Regress z ~ selected covariates
#    # BART::pbart is for probit regression
#  prop_bart <- pbart(
#    x.train = select(dat, all_of(var_select)),
#    y.train = pull(dat, z),
#    nskip = 2000,
#    ndpost = 5000
#  )
#  
#  # store propensity score in data
#  dat$prop_score <-  prop_bart$prob.train.mean
#  
#  # Step 4 TE Model: Regress y ~ z + covariates + propensity score
#  te_model <- wbart(
#    x.train = select(dat,-y),
#    y.train = pull(dat, y),
#    nskip = 10000L,
#    ndpost = 200L, #*
#    keepevery = 100L #*
#  )
#  
#  #* The posterior samples are kept small to manage size on CRAN
#  

## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------

# Tidy the posterior fitted values of the treatment-effect model, storing each
# draw in a column named `fit`. include_newdata = FALSE keeps the (large)
# model data from being attached to every draw; the `.row` column records
# which row of the model data each fitted value came from, so covariates can
# be joined back on later.
# NOTE(review): the BART note here says `newdata` must be specified for this
# option, yet none is passed — presumably it is only required when
# include_newdata = TRUE; confirm against tidytreatment's docs.
posterior_fitted <- te_model %>%
  fitted_draws(value = "fit", include_newdata = FALSE)

posterior_fitted


## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
#  
#  # Tidy the predicted draws as well; by default this adds random normal noise
#  posterior_pred <- predicted_draws(te_model, include_newdata = FALSE)
#  

## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------

# Per-observation lookup of the treatment (as a factor, for a discrete x axis)
# and the categorical covariate c1, keyed by `.row` so it can be joined onto
# the fitted draws, which were created without the model data.
# row_number() instead of 1:n(): identical on non-empty data, and it yields
# integer(0) rather than c(1, 0) when there are no rows.
treatment_var_and_c1 <-
  dat %>%
  select(z, c1) %>%
  mutate(.row = row_number(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fit)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")


## ----post-treatment, eval = FALSE---------------------------------------------
#  
#  # sample based (using data from fit) conditional treatment effects, posterior draws
#  posterior_treat_eff <-
#    treatment_effects(te_model, treatment = "z", newdata = dat)
#  

## ----cates-hist, echo=TRUE, cache=FALSE---------------------------------------

# Histogram of the conditional treatment effect pooled over every posterior
# draw of every subject.
ggplot(posterior_treat_eff) +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

# Collapse to one posterior median per subject before plotting.
# NOTE(review): summarise() is called without group_by(), so this relies on
# `posterior_treat_eff` already being grouped per subject — confirm.
ggplot(summarise(posterior_treat_eff, cte_hat = median(cte))) +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")


## ----att-ate, eval=FALSE------------------------------------------------------
#  # get the ATE and ATT directly:
#  
#  posterior_ate <- tidy_ate(te_model, treatment = "z", newdata = dat)
#  posterior_att <- tidy_att(te_model, treatment = "z", newdata = dat)
#  

## ----ate-trace-setup, eval = TRUE, echo = FALSE-------------------------------

# Average treatment effect per posterior draw: average the conditional effects
# over all subjects within each (.chain, .iteration, .draw), then drop groups.
posterior_ate <- summarise(
  group_by(posterior_treat_eff, .chain, .iteration, .draw),
  ate = mean(cte),
  .groups = "drop"
)


## ----ate-trace, eval=TRUE, echo=TRUE------------------------------------------

# Trace of the ATE across posterior draws, for a visual check of mixing.
ggplot(posterior_ate, aes(x = .draw, y = ate)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of ATE")


## ----post-te-treated, echo=TRUE, eval=FALSE-----------------------------------
#  
#  # sample based (using data from fit) conditional treatment effects, posterior draws
#  posterior_treat_eff_on_treated <-
#    treatment_effects(te_model, treatment = "z", newdata = dat, subset = "treated")
#  

## ----cates-hist-treated, echo=TRUE, cache=FALSE-------------------------------

# Same histogram as for all subjects, restricted to the draws computed with
# subset = "treated".
ggplot(posterior_treat_eff_on_treated) +
  geom_histogram(
    mapping = aes(x = cte),
    binwidth = 0.1,
    colour = "white"
  ) +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws from treated subjects)")


## ----cates-stack-plot, echo=TRUE, cache=FALSE---------------------------------

# Caterpillar plot: one point estimate and interval per subject, sorted by the
# point estimate so heterogeneity in the CATEs is visible. The axis label
# assumes point_interval()'s defaults (median, 95% interval) — confirm if the
# defaults are changed. row_number() replaces 1:n(), which misbehaves on
# empty input (1:0 is c(1, 0)).
posterior_treat_eff %>%
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  mutate(.orow = row_number()) %>%
  ggplot() +
  geom_interval(aes(x = .orow, y = cte, ymin = .lower, ymax = .upper)) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) +
  ylab("Median posterior CATE for each subject (95% CI)") +
  theme_bw() +
  coord_flip() +
  scale_colour_brewer() +
  # After coord_flip() the vertical axis shows the subject index, which
  # carries no meaning, so hide its title, labels and ticks.
  theme(
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "none"
  )


## ----cates-line-plot, echo=TRUE, cache=FALSE----------------------------------

# Distribution of the conditional treatment-effect draws within each level of
# the categorical covariate c1. c1 is attached to each draw by joining on the
# `.row` index into the model data; seq_along() replaces 1:length(), which
# would produce c(1, 0) for an empty vector.
posterior_treat_eff %>%
  left_join(tibble(c1 = dat$c1, .row = seq_along(dat$c1)), by = ".row") %>%
  group_by(c1) %>%
  ggplot() +
  stat_halfeye(aes(x = c1, y = cte), alpha = 0.7) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Treatment effect by `c1`")



## ----common-support, echo=TRUE, results='hide', cache=FALSE-------------------

# Check common support directly, using two different criteria.
# For BART models the model data has to be passed via `modeldata`.

# Criterion 1: method "chisq" with cutoff 0.05.
csupp_chisq <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "chisq",
  cutoff = 0.05
)
filter(csupp_chisq, !common_support)

# Criterion 2: method "sd" with cutoff 1.
csupp_sd <- has_common_support(
  te_model,
  treatment = "z",
  modeldata = dat,
  method = "sd",
  cutoff = 1
)
filter(csupp_sd, !common_support)

# Recompute the treatment effects for treated subjects, keeping only the
# estimates that have common support under the "sd" criterion.
posterior_treat_eff_on_treated <- treatment_effects(
  te_model,
  treatment = "z",
  subset = "treated",
  newdata = dat,
  common_support_method = "sd",
  cutoff = 1
)


## ----interaction-investigator, echo=TRUE, cache=FALSE-------------------------

  # Average inclusion of each covariate alongside the treatment `z`
  # (presumably a rough indicator of effect heterogeneity — confirm against
  # covariate_with_treatment_importance()'s documentation).
  treatment_interactions <- covariate_with_treatment_importance(
    te_model,
    treatment = "z"
  )

  ggplot(treatment_interactions) +
    geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
    theme_bw() +
    ggtitle("Important variables interacting with treatment ('z')") +
    ylab("Inclusion counts") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  # Overall average inclusion of each covariate in the treatment-effect model.
  variable_importance <- covariate_importance(te_model)

  ggplot(variable_importance) +
    geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
    theme_bw() +
    ggtitle("Important variables overall") +
    ylab("Inclusion counts") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  

## ----sigma-trace, echo=TRUE, cache=FALSE--------------------------------------

# Trace plot of the model variance. The draws include the skipped (burn-in)
# MCMC samples, so keep only draws after 10000 — the `nskip = 10000L` used
# when fitting `te_model` — to show the post warm-up chain.
# `value` names the output column, which also becomes the y-axis label, so
# spell it "sigsq" rather than the mistyped "siqsq".
variance_draws(te_model, value = "sigsq") %>%
  filter(.draw > 10000) %>%
  ggplot(aes(x = .draw, y = sigsq)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of model variance post warm-up")


## ----convergence-bart, echo=TRUE, cache=FALSE---------------------------------

# Residual draws of the treatment-effect model against the observed outcome y.
res <- residual_draws(te_model, response = pull(dat, y), include_newdata = FALSE)

# Residuals vs observed outcome: a point estimate and 95% interval of the
# residual per observation. The interval columns for `y` are dropped because
# y is the same observed value in every draw of an observation.
res %>%
  point_interval(.residual, y, .width = 0.95) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() +
  geom_pointinterval(
    aes(x = y, y = .residual, ymin = .residual.lower, ymax = .residual.upper),
    alpha = 0.2
  ) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Residuals vs observations")

# Posterior-mean fitted value against the observed outcome, with a linear fit.
# NOTE(review): summarise() assumes `res` is grouped per observation — confirm.
# The explicit `formula` stops geom_smooth() printing its
# "using formula" message into the rendered vignette.
res %>%
  summarise(.fitted = mean(.fitted), y = first(y)) %>%
  ggplot(aes(x = y, y = .fitted)) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x) +
  theme_bw() +
  ggtitle("Observations vs fitted")

# Normal Q-Q plot of the posterior-mean residual of each observation.
res %>%
  summarise(.residual = mean(.residual)) %>%
  ggplot(aes(sample = .residual)) +
  geom_qq() +
  geom_qq_line() +
  theme_bw() +
  ggtitle("Q-Q plot of residuals")

# Try the tidytreatment package in your browser

# Any scripts or data that you put into this service are public.

# tidytreatment documentation built on March 18, 2022, 6:30 p.m.