# R/mediate_b.R
#
# Defines functions mediate_b
#
# Documented in mediate_b

#' Mediation using Bayesian methods
#' 
#' 
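#' Mediation analysis done in the framework of Imai et al. (2010).  When both 
#' models are \code{lm_b} fits and the outcome model has no interaction with 
#' the mediator, the effects are computed directly from the posterior draws of 
#' the coefficients; otherwise (e.g., \code{glm_b} fits) they are computed by 
#' simulating counterfactual mediators and outcomes.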
#' 
#' @details
#' The model is the same as that of Imai et al. (2010):
#' \deqn{
#'  M_i(X) =  w_i'\alpha_m + X\beta_m + \epsilon_{m,i}, \\
#'  y_i(X, M(\tilde X)) = w_i'\alpha_y + X\beta_y + M(\tilde X)\gamma + \epsilon_{y,i}, \\
#'  \epsilon_{m,i} \overset{iid}{\sim} N(0,\sigma^2_m), \\
#'  \epsilon_{y,i} \overset{iid}{\sim} N(0,\sigma^2_y), \\
#' }
#' where \eqn{M_i(X)} is the mediator as a function of the treatment variable
#' \eqn{X}, and \eqn{w_i} are confounder covariates.
#' 
#' 
#' Unlike the \code{mediation} R package, the estimation in \code{mediate_b} 
#' is fully Bayesian (as opposed to "quasi-Bayesian").
#' 
#' 
#' @references 
#' 
#' Imai, Kosuke, et al. 
#' “A General Approach to Causal Mediation Analysis.” Psychological Methods, 
#' vol. 15, no. 4, 2010, pp. 309–34, https://doi.org/10.1037/a0020761.
#' 
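#' @param model_m a fitted model object of class \code{lm_b} or \code{glm_b} 
#' for the mediator.
#' @param model_y a fitted model object of class \code{lm_b} or \code{glm_b} 
#' for the outcome.  Must have been fit to the same data as \code{model_m}.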
#' @param treat a character string indicating the name of the 
#' treatment variable used in the models.  NOTE: Treatment variable must be
#' numeric (even if it's 1's and 0's).
#' @param control_value value of the treatment variable used as the 
#' control condition. Default is the 1st quintile of the treat variable.
#' @param treat_value value of the treatment variable used as the treatment condition. 
#' Default is the 4th quintile of the treat variable.
#' @param n_draws Number of preliminary posterior draws to assess final 
#' number of posterior draws required for accurate interval estimation
#' @param ask_before_full_sampling logical.  If FALSE, the user will not 
#' be asked if they want to complete the full sampling.  Defaults to 
#' TRUE, as this can be a computationally intensive procedure.
#' @param CI_level numeric. Credible interval level.
#' @param seed integer.  Always set your seed!!!
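#' @param mc_error positive scalar.  Relative Monte Carlo error.  It is 
#' rescaled internally to the scale of the outcome: multiplied by 
#' \code{4 * sd(y)} for \code{lm_b} outcome models, by 
#' \code{4 * sd(log(y + 1))} for non-binomial \code{glm_b} outcome models, 
#' and used as given for binomial \code{glm_b} outcome models.  The number of 
#' posterior samples will, with high probability, estimate the CI bounds up to 
#' \eqn{\pm} this rescaled error.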
#' @param batch_size positive integer.  Number of posterior draws to be 
#' taken at once.  Higher values are more computationally intensive, but 
#' values which are too high might take up significant memory (allocates 
#' on the order of \code{batch_size}\eqn{\times}\code{nrow(model_y$data)}).
#' 
#' @returns A list with the following elements:
#' \itemize{
#'  \item \code{summary} - tibble giving results for causal mediation quantities
#'  \item \code{posterior_draws} (of counterfactual expectations)
#'  \item \code{mc_error} absolute error used, including any rescaling 
#'  to match the scale of the outcome
#'  \item other inputs to \code{mediate_b}
#' }
#' 
#' 
#' @examples
#' \donttest{
#' # Simplest case
#' ## Generate some data
#' set.seed(2025)
#' N = 500
#' test_data = 
#'   data.frame(tr = rnorm(N),
#'              x1 = rnorm(N))
#' test_data$m = 
#'   rnorm(N, 0.4 * test_data$tr - 0.25 * test_data$x1)
#' test_data$outcome = 
#'   rnorm(N,-1 + 0.6 * test_data$tr + 1.5 * test_data$m + 0.25 * test_data$x1)
#' 
#' ## Fit the mediator and outcome models
#' m1 = 
#'   lm_b(m ~ tr + x1,
#'        data = test_data)
#' m2 = 
#'   lm_b(outcome ~ m + tr + x1,
#'        data = test_data)
#' ## Estimate the causal mediation quantities
#' m3 <-
#'   mediate_b(m1,m2,
#'             treat = "tr",
#'             control_value = -2,
#'             treat_value = 2,
#'             n_draws = 500,
#'             mc_error = 0.05,
#'             ask_before_full_sampling = FALSE)
#' m3
#' summary(m3,
#'         CI_level = 0.9)
#' 
#' # More complicated scenario
#' ## Generate some data
#' set.seed(2025)
#' N = 500
#' test_data = 
#'   data.frame(tr = rep(0:1,N/2),
#'              x1 = rnorm(N))
#' test_data$m = 
#'   rnorm(N, 0.4 * test_data$tr - 0.25 * test_data$x1)
#' test_data$outcome = 
#'   rpois(N,exp(-1 + 0.6 * test_data$tr + 1.5 * test_data$m + 0.25 * test_data$x1))
#' 
#' ## Fit the mediator and outcome models
#' m1 = 
#'   lm_b(m ~ tr + x1,
#'        data = test_data)
#' m2 = 
#'   glm_b(outcome ~ m + tr + x1,
#'         data = test_data,
#'         family = poisson())
#' 
#' ##  Estimate the causal mediation quantities
#' m3 <-
#'   mediate_b(m1,m2,
#'             treat = "tr",
#'             control_value = 0,
#'             treat_value = 1,
#'             n_draws = 500,
#'             mc_error = 0.05,
#'             ask_before_full_sampling = FALSE)
#' summary(m3)
#' }
#' 
#' 
#' 
#' @export

mediate_b = function(model_m,
                     model_y,
                     treat,
                     control_value,
                     treat_value,
                     n_draws = 500,
                     ask_before_full_sampling = TRUE,
                     CI_level = 0.95,
                     seed = 1,
                     mc_error = ifelse(inherits(model_y, "glm_b"),
                                       0.01,0.002),
                     batch_size = 500){
  # NOTE: this sets the global RNG state (documented via the `seed` argument).
  set.seed(seed)
  alpha_ci = 1 - CI_level
  
  # Rescale mc_error to the scale of the outcome.  Binomial glm_b outcome
  # models keep mc_error as supplied.
  if(inherits(model_y, "lm_b")){
    y =
      stats::model.response(stats::model.frame(stats::terms(model_y),
                                               model_y$data))
    mc_error = mc_error * 4 * stats::sd(y)
  }
  if( inherits(model_y, "glm_b") &&
      (model_y$family$family != "binomial") ){
    y =
      stats::model.response(stats::model.frame(stats::terms(model_y),
                                               model_y$data))
    mc_error = mc_error * 4 * stats::sd(log(y + 1))
  }
  
  # all.equal() returns a character vector describing the differences (not
  # FALSE) when its arguments differ, so the result must go through isTRUE().
  if(!isTRUE(all.equal(model_m$data,model_y$data))){
    stop("Data in model_m and model_y must match.")
  }
  
  mediator = as.character(model_m$formula)[[2]]
  
  if(missing(control_value)){
    message(paste0("control_value missing; set to be the 1st quintile of ",
                   treat))
    control_value = stats::quantile(model_m$data[[treat]],probs = 0.2)
  }
  if(missing(treat_value)){
    message(paste0("treat_value missing; set to be the 4th quintile of ",
                   treat))
    treat_value = stats::quantile(model_m$data[[treat]],probs = 0.8)
  }
  
  # The direct (coefficient-product) computation is used only when both models
  # are lm_b fits and no outcome-model term pairs the mediator with ":".
  tl = attr(model_y$terms,"term.labels")
  simple = 
    inherits(model_m, "lm_b") &&
    inherits(model_y, "lm_b") && 
    !any(grepl(paste0(":",mediator),tl) | 
           grepl(paste0(mediator,":"),tl))
  
  # Helper: number of additional draws (beyond n_draws) needed so that the
  # lower CI bound of every column in `cols` is estimated to within mc_error
  # with 99% probability.  Uses the asymptotic variance of a sample quantile,
  # p(1 - p) / (n f(q)^2), with f(q) taken from a kernel density estimate
  # evaluated at the grid point closest to the empirical quantile.
  extra_draws_needed = function(draws, cols){
    tail_prob = 0.5 * alpha_ci
    per_column = 
      future.apply::future_sapply(cols,
                                  function(i){
                                    x = unlist(draws[,i])
                                    fhat = stats::density(x,adjust = 2)
                                    f_at_bound = 
                                      fhat$y[which.min(abs(fhat$x - 
                                                             stats::quantile(x, tail_prob)))]
                                    tail_prob * (1.0 - tail_prob) *
                                      (
                                        stats::qnorm(0.5 * (1.0 - 0.99)) / 
                                          mc_error /
                                          f_at_bound
                                      )^2
                                  })
    round(max(per_column)) - n_draws
  }
  
  # Helper: decide whether to run the full sampling.  `proceed` is TRUE only
  # on an explicit "yes" (askYesNo() returns NA when the user cancels);
  # `warn` is FALSE when the preliminary draws already suffice.
  confirm_full_sampling = function(n_more){
    if(n_more <= 0){
      return(list(proceed = FALSE, warn = FALSE))
    }
    proceed = TRUE
    if(ask_before_full_sampling){
      proceed = 
        utils::askYesNo(paste0(n_more,
                               " more draws are required for accurate CI bounds.\nShould sampling proceed? (yes/no)"))
    }
    list(proceed = isTRUE(proceed), warn = TRUE)
  }
  
  results = list()
  if(simple){
    
    # Helper: n posterior draws of ACME, ADE and total effect computed from
    # the coefficient draws of the two models.
    draw_simple_effects = function(n){
      mediator_draws = 
        get_posterior_draws(model_m,
                            n_draws = n)
      outcome_draws = 
        get_posterior_draws(model_y,
                            n_draws = n)
      draws = 
        tibble::tibble(
          ACME = 
            (treat_value - control_value) * 
            mediator_draws[,treat] *
            outcome_draws[,all.vars(model_m$formula)[1]],
          ADE =
            (treat_value - control_value) * 
            outcome_draws[,treat])
      draws$`Total Effect` = 
        draws$ACME + draws$ADE
      draws
    }
    
    # Get posterior draws for ACME and ADE
    ## Get preliminary draws
    results$posterior_draws = draw_simple_effects(n_draws)
    
    ## Evaluate number of draws required for accurate CI bounds
    ## (columns 1:2 are ACME and ADE)
    n_more_draws = extra_draws_needed(results$posterior_draws, 1:2)
    
    # Finish sampling
    decision = confirm_full_sampling(n_more_draws)
    
    if(decision$proceed){
      message("Continuing on with ",
              n_more_draws,
              " more posterior samples.\n")
      
      # Draw the number of samples still required (not n_draws again)
      results$posterior_draws = 
        dplyr::bind_rows(results$posterior_draws,
                         draw_simple_effects(n_more_draws))
      
    }else{
      results$message = 
        paste0(n_draws + n_more_draws,
               " total draws are required for accurate CI bounds.")
      if(decision$warn) message(results$message)
    }
    
    pd = results$posterior_draws
    prop_mediated = pd$ACME / pd$`Total Effect`
    
    results$summary = 
      tibble::tibble(Estimand = c("ACME",
                                  "ADE",
                                  "Total Effect",
                                  "Prop. Mediated"),
                     Estimate = 
                       c(mean(pd$ACME),
                         mean(pd$ADE),
                         mean(pd$`Total Effect`),
                         mean(prop_mediated)),
                     Lower = 
                       c(stats::quantile(pd$ACME,
                                         probs = 0.5 * alpha_ci),
                         stats::quantile(pd$ADE,
                                         probs = 0.5 * alpha_ci),
                         stats::quantile(pd$`Total Effect`,
                                         probs = 0.5 * alpha_ci),
                         stats::quantile(prop_mediated,
                                         probs = 0.5 * alpha_ci)),
                     Upper =
                       c(stats::quantile(pd$ACME,
                                         probs = 1.0 - 0.5 * alpha_ci),
                         stats::quantile(pd$ADE,
                                         probs = 1.0 - 0.5 * alpha_ci),
                         stats::quantile(pd$`Total Effect`,
                                         probs = 1.0 - 0.5 * alpha_ci),
                         stats::quantile(prop_mediated,
                                         probs = 1.0 - 0.5 * alpha_ci)),
                     `Prob Dir` = 
                       c(mean(pd$ACME > 0),
                         mean(pd$ADE > 0),
                         mean(pd$`Total Effect` > 0),
                         NA)
      )
    
    
  }else{
    
    # glm_b fits that used the "IS" algorithm are refit with "VB" before
    # simulating.  Each refit is assigned back to the model it came from.
    if(inherits(model_m, "glm_b") && (model_m$algorithm == "IS")){
      suppressMessages({
        model_m <- 
          glm_b(formula = model_m$formula,
                data = model_m$data,
                family = model_m$family,
                trials = model_m$trials,
                prior_beta_mean = model_m$hyperparameters$prior_beta_mean,
                prior_beta_precision = model_m$hyperparameters$prior_beta_precision,
                algorithm = "VB")
      })
    }
    if(inherits(model_y, "glm_b") && (model_y$algorithm == "IS")){
      suppressMessages({
        model_y <- 
          glm_b(formula = model_y$formula,
                data = model_y$data,
                family = model_y$family,
                trials = model_y$trials,
                prior_beta_mean = model_y$hyperparameters$prior_beta_mean,
                prior_beta_precision = model_y$hyperparameters$prior_beta_precision,
                algorithm = "VB")
      })
    }
    
    # Setup counterfactual data for posterior draws
    counterfactual_data0 = 
      counterfactual_data1 = 
      model_m$data
    counterfactual_data0[[treat]] = control_value
    counterfactual_data1[[treat]] = treat_value
    
    
    # Create sampling function: returns a tibble with one row per posterior
    # draw and columns Tot_Eff, ACME_control, ACME_treat, ADE_control,
    # ADE_treat.
    draw_y_x_mx = function(n_iter){
      
      # Free memory between the large intermediate objects
      free_memory = function() invisible(gc(verbose = FALSE))
      
      # Draw n_iter new mediators for every row of newdata, drop the observed
      # mediator and the summary columns, and stack the "y_new*" columns into
      # long format with the draw index in `posterior_draw`.
      draw_mediator = function(newdata){
        M = 
          predict(model_m,
                  newdata = newdata,
                  n_draws = n_iter)
        M = 
          M[,setdiff(colnames(M),
                     c(mediator,
                       "Post Mean",
                       "PI_lower",
                       "PI_upper",
                       "CI_lower",
                       "CI_upper"))]
        M |> 
          tidyr::pivot_longer(cols = contains("y_new"),
                              names_to = "posterior_draw",
                              values_to = mediator,
                              names_prefix = "y_new")
      }
      
      # Average the single predicted outcome draw ("y_new1") over the data
      # rows within each posterior draw.
      mean_by_draw = function(pred){
        pred |> 
          dplyr::group_by(.data$posterior_draw) |> 
          dplyr::summarize(mean = mean(.data[["y_new1"]])) |> 
          dplyr::pull(mean)
      }
      
      ## Draw new mediators
      M_0 = draw_mediator(counterfactual_data0)
      M_1 = draw_mediator(counterfactual_data1)
      free_memory()
      
      ## Draw outcomes; y_ab uses treatment level a and mediator M(b).
      ## The order of the predict() calls is kept fixed so that the random
      ## number stream is consumed in the same sequence.
      y_00 =
        predict(model_y,
                newdata = M_0,
                n_draws = 1)
      free_memory()
      y_11 =
        predict(model_y,
                newdata = M_1,
                n_draws = 1)
      free_memory()
      M_1[[treat]] = control_value
      y_01 =
        predict(model_y,
                newdata = M_1,
                n_draws = 1)
      free_memory()
      M_0[[treat]] = treat_value
      y_10 =
        predict(model_y,
                newdata = M_0,
                n_draws = 1)
      free_memory()
      
      
      E_00 = mean_by_draw(y_00)
      E_11 = mean_by_draw(y_11)
      E_01 = mean_by_draw(y_01)
      E_10 = mean_by_draw(y_10)
      
      ret = 
        tibble::tibble(Tot_Eff = E_11 - E_00,
                       ACME_control = E_01 - E_00,
                       ACME_treat = E_11 - E_10,
                       ADE_control = E_10 - E_00,
                       ADE_treat = E_11 - E_01)
      
      rm(M_0,M_1,y_00,y_11,y_10,y_01,E_00,E_11,E_10,E_01)
      free_memory()
      
      return( ret )
    }
    
    # Get preliminary posterior draws
    prelim_draws = 
      draw_y_x_mx(n_draws) |> 
      stats::na.omit()
    message(paste0("\nFinished with ",
               n_draws,
               " preliminary posterior draws.\n"))
    
    
    ## Evaluate number of draws required for accurate CI bounds
    ## (column 1, Tot_Eff, is not used for this calculation)
    n_more_draws = extra_draws_needed(prelim_draws, 2:NCOL(prelim_draws))
    
    decision = confirm_full_sampling(n_more_draws)
    
    if(decision$proceed){
      message("Continuing on with ",
          n_more_draws,
          " more posterior samples.\n")
      
      # Do it in batches to save memory (each batch has at least 2 draws)
      batch_size_vector = 
        c(seq(1,n_more_draws,by = batch_size),n_more_draws + 1) |> 
        diff() |> 
        pmax(2)
      new_draws = 
        do.call(dplyr::bind_rows,
                future.apply::future_lapply(seq_along(batch_size_vector),
                                            function(b){
                                              suppressWarnings(suppressPackageStartupMessages(library(bayesics)))
                                              draw_y_x_mx(batch_size_vector[b])
                                            },
                                            future.seed = seed + 1)
        ) |> 
        stats::na.omit()
      # n_more_draws is net of the preliminary draws, so keep them as part of
      # the final sample.
      results$posterior_draws = 
        dplyr::bind_rows(prelim_draws,
                         new_draws)
    }else{
      results$message = 
        paste0(n_draws + n_more_draws,
               " total draws are required for accurate CI bounds.")
      results$posterior_draws = 
        prelim_draws
      if(decision$warn) warning(results$message)
    }
    
    
    
    # Put it together to return
    pd = results$posterior_draws
    acme_sum = pd$ACME_control + pd$ACME_treat
    ade_sum = pd$ADE_control + pd$ADE_treat
    prop_mediated = 
      acme_sum / 
      (pd$ACME_control + pd$ACME_treat + pd$ADE_control + pd$ADE_treat)
    lower_p = 0.5 * alpha_ci
    upper_p = 1.0 - 0.5 * alpha_ci
    
    results$summary = 
      tibble::tibble(Estimand = c("ACME (Control)",
                                  "ACME (Treatment)",
                                  "ADE (Control)",
                                  "ADE (Treatment)",
                                  "Total Effect",
                                  "ACME (Average)",
                                  "ADE (Average)",
                                  "Prop. Mediated (Average)"),
                     Estimate = 
                       c(mean(pd$ACME_control),
                         mean(pd$ACME_treat),
                         mean(pd$ADE_control),
                         mean(pd$ADE_treat),
                         mean(pd$Tot_Eff),
                         0.5 * mean(acme_sum),
                         0.5 * mean(ade_sum),
                         mean(prop_mediated)
                       ),
                     Lower = 
                       c(stats::quantile(pd$ACME_control,lower_p),
                         stats::quantile(pd$ACME_treat,lower_p),
                         stats::quantile(pd$ADE_control,lower_p),
                         stats::quantile(pd$ADE_treat,lower_p),
                         stats::quantile(pd$Tot_Eff,lower_p),
                         0.5 * stats::quantile(acme_sum,lower_p),
                         0.5 * stats::quantile(ade_sum,lower_p),
                         stats::quantile(prop_mediated,lower_p)
                       ),
                     Upper = 
                       c(stats::quantile(pd$ACME_control,upper_p),
                         stats::quantile(pd$ACME_treat,upper_p),
                         stats::quantile(pd$ADE_control,upper_p),
                         stats::quantile(pd$ADE_treat,upper_p),
                         stats::quantile(pd$Tot_Eff,upper_p),
                         0.5 * stats::quantile(acme_sum,upper_p),
                         0.5 * stats::quantile(ade_sum,upper_p),
                         stats::quantile(prop_mediated,upper_p)
                       ),
                     `Prob Dir` = 
                       c(mean(pd$ACME_control > 0),
                         mean(pd$ACME_treat > 0),
                         mean(pd$ADE_control > 0),
                         mean(pd$ADE_treat > 0),
                         mean(pd$Tot_Eff > 0),
                         mean(acme_sum > 0), # No need to multiply by 0.5 for PDir
                         mean(ade_sum > 0),
                         NA)
      )
  }
  
  # Probability of direction: the larger of P(effect > 0) and P(effect <= 0)
  results$summary$`Prob Dir` = 
    pmax(results$summary$`Prob Dir`,
         1.0 - results$summary$`Prob Dir`)
  
  # Don't report negative or >1 proportions (last row is the prop. mediated)
  last_row = nrow(results$summary)
  results$summary$Estimate[last_row] = 
    min(max(results$summary$Estimate[last_row],0.0),1.0)
  results$summary$Lower[last_row] = 
    min(max(results$summary$Lower[last_row],0.0),1.00)
  results$summary$Upper[last_row] = 
    max(min(results$summary$Upper[last_row],1.0),0.0)
  
  results$treat_value = treat_value
  results$control_value = control_value
  results$model_m = model_m
  results$model_y = model_y
  results$CI_level = CI_level
  results$mc_error = mc_error
  
  
  return(structure(results,
                   class = "mediate_b"))
}

# Try the bayesics package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bayesics documentation built on March 11, 2026, 5:07 p.m.