R/plot_moderators.R

Defines functions table_moderator_c_bin rpart_ggplot_ plot_moderator_search plot_moderator_d_linerange plot_moderator_d plot_moderator_c_bin plot_moderator_c_loess plot_moderator_c_pd

Documented in plot_moderator_c_bin plot_moderator_c_loess plot_moderator_c_pd plot_moderator_d plot_moderator_d_linerange plot_moderator_search table_moderator_c_bin

#' @title Partial dependency plot of a continuous moderating variable
#' @description Plot a partial dependency plot with a continuous covariate from a 'bartCause' model. Identify treatment effect variation predicted across levels of a continuous variable.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param n_bins number of bins to cut the moderator with. Defaults to the lesser of 15 and number of distinct levels of the moderator
#' @details Partial dependency plots are one way to evaluate heterogeneous treatment effects that vary by values of a continuous covariate. For more information on partial dependency plots from BART causal inference models see Green and Kern 2012.
#' @author George Perrett, Joseph Marlo
#'
#' @references
#' Green, D. P., & Kern, H. L. (2012).
#' Modeling heterogeneous treatment effects in survey experiments with Bayesian additive regression trees.
#' Public opinion quarterly, 76(3), 491-511.
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom stats density median na.omit predict quantile sd
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none',
#'  keepTrees = TRUE
#' )
#' plot_moderator_c_pd(model_results, lalonde$age)
#' }
plot_moderator_c_pd <- function(.model, moderator, n_bins = NULL){

  validate_model_(.model)

  # validate n_bins argument
  n_mod_levels <- n_distinct(moderator)
  if (n_mod_levels <= 1) stop('dplyr::n_distinct(moderator) must be at least 2')
  # NOTE(review): pclamp_() is defined elsewhere; presumably it clamps 15 to
  # the interval [2, n_mod_levels] -- confirm against the helper
  if (is.null(n_bins)) n_bins <- pclamp_(15, 2, n_mod_levels)
  # scalar condition, so use the short-circuiting `&&`
  if (!(n_bins > 1 && n_bins <= n_mod_levels)) stop("n_bins must be greater than 1 and less than or equal to dplyr::n_distinct(moderator)")

  # extract data from model
  new_data <- as_tibble(.model$data.rsp@x)

  # create dataframe where all observations are treated
  new_data_z1 <- new_data
  name_trt <- .model$name.trt
  new_data_z1[, name_trt] <- 1
  new_data_z1 <- cbind.data.frame(y_response = .model$fit.rsp$y, new_data_z1)

  # create dataframe where all observations are control
  new_data_z0 <- new_data
  new_data_z0[, name_trt] <- 0
  new_data_z0 <- cbind.data.frame(y_response = .model$fit.rsp$y, new_data_z0)

  # locate the moderator in bartc data
  # A column matches only when it has the same length and the same values in
  # the same order as `moderator`. Comparing sums of differences would also
  # match any column that merely shares the moderator's total (for example a
  # permutation of it).
  is_moderator <- vapply(
    new_data_z0,
    function(x) {
      length(x) == length(moderator) &&
        isTRUE(all.equal(as.numeric(x), as.numeric(moderator)))
    },
    logical(1)
  )
  index <- which(is_moderator)
  if (length(index) == 0) stop('Cannot find moderator in original data. Is moderator within the original dataframe used to fit the .model?')
  # duplicated columns are interchangeable, so the first match is used; the
  # index is relative to new_data_z0 / new_data_z1 (y_response is column 1)
  index <- index[1]

  # get range for predictions: an evenly spaced grid of n_bins points, unless
  # that grid is at least as large as the set of observed values, in which case
  # the observed values themselves are used
  n_steps <- n_bins - 1
  grid <- seq(min(moderator), max(moderator), (max(moderator) - min(moderator)) / n_steps)
  if (length(grid) < n_mod_levels) {
    .range <- grid
  } else{
    .range <- sort(unique(moderator))
  }

  # predict new data with overridden treatment columns
  # (fit_pd_() is defined elsewhere; it is called once per value in .range)
  cates <- lapply(.range, fit_pd_, z1 = new_data_z1, z0 = new_data_z0, index = index, .model = .model)
  names(cates) <- seq_along(cates)
  cates <- bind_cols(cates)
  cates.m <- apply(cates, MARGIN = 2, FUN = mean)
  cates.m <- bind_cols(cates.m = cates.m, .range = .range)

  # get credible intervals
  ci_range <- c(0.025, 0.1, 0.9, 0.975)
  cates.ci <- as_tibble(t(apply(cates, MARGIN = 2, FUN = quantile, probs = ci_range)))
  cates_plot <- bind_cols(cates.m, cates.ci)
  # columns 1:2 are cates.m and .range; the quantile columns follow and are
  # renamed ci_2.5, ci_10, ci_90, ci_97.5
  indices <- 2 + seq_along(ci_range)
  colnames(cates_plot)[indices] <- paste0('ci_', as.character(ci_range * 100))

  # plot it
  p <- ggplot(cates_plot) +
    geom_ribbon(aes(x = .range, y = cates.m, ymin = ci_2.5, ymax = ci_97.5, fill = '95% ci')) +
    geom_ribbon(aes(x = .range, y = cates.m, ymin = ci_10, ymax = ci_90, fill = '80% ci')) +
    scale_fill_manual(values = c('grey40', 'grey60')) +
    geom_point(aes(x = .range, y = cates.m), size = 2) +
    geom_line(aes(x = .range, y = cates.m)) +
    labs(title = NULL,
         x = NULL,
         y = 'CATE') +
    theme(legend.position = "bottom")

  return(p)
}


#' @title LOESS plot of a continuous moderating variable
#' @description Plot the LOESS prediction of ICATEs by a continuous covariate. This is an alternative to partial dependency plots to assess treatment effect heterogeneity by a continuous covariate. See Carnegie, Dorie and Hill 2019.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param line_color the color of the loess line
#'
#' @author George Perrett, Joseph Marlo
#'
#' @references
#' Carnegie, N., Dorie, V., & Hill, J. L. (2019).
#' Examining treatment effect heterogeneity using BART.
#' Observational Studies, 5(2), 52-70.
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom bartCause extract
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' plot_moderator_c_loess(model_results, lalonde$age)
#' }
plot_moderator_c_loess <- function(.model, moderator, line_color = 'blue'){

  # argument checks (helpers are defined elsewhere in the package)
  validate_model_(.model)
  is_numeric_vector_(moderator)

  # align the moderator with the observations covered by the estimand
  moderator <- adjust_for_estimand_(.model, moderator)

  # pull the ICATE posterior and transpose it before converting to a tibble
  icate_draws <- bartCause::extract(.model, 'icate')
  icate_by_row <- as_tibble(as.data.frame(t(icate_draws)))

  # group the rows by moderator value, then average across columns within
  # each row of every group
  rows_by_level <- split(icate_by_row, moderator)
  row_means_by_level <- lapply(rows_by_level, rowMeans)

  # flatten to one value per row; split() orders groups by moderator value,
  # so the sorted moderator is attached alongside
  dat <- data.frame(value = unlist(row_means_by_level))
  dat$moderator <- moderator[order(moderator)]
  rownames(dat) <- seq_len(nrow(dat))

  # scatter of the averaged ICATEs with a LOESS smoother on top
  smoother <- geom_smooth(
    method = 'loess',
    formula = y ~ x,
    se = TRUE,
    size = 1.5,
    color = line_color
  )
  axis_labels <- labs(title = NULL, x = NULL, y = 'icate')
  p <- ggplot(dat, aes(moderator, value)) +
    geom_point() +
    smoother +
    axis_labels

  return(p)
}


#' @title Auto-Bin a plot of a continuous moderating variable into a discrete moderating variable
#' @description Use a regression tree to optimally bin a continuous variable
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param .alpha transparency value [0, 1]
#' @param facet TRUE/FALSE. Create panel plots of each moderator level?
#' @param .ncol number of columns to use when faceting
#' @param type string to specify if you would like to plot a histogram, density or error bar plot
#' @param .name string representing the name of the moderating variable
#'
#' @author George Perrett
#'
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom bartCause extract
#' @importFrom rpart rpart
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' plot_moderator_c_bin(model_results, lalonde$age, .name = 'age')
#' }
plot_moderator_c_bin <- function(.model, moderator,type = c('density', 'histogram', 'errorbar'), .alpha = 0.7, facet = FALSE, .ncol = 1, .name = 'bin'){

  validate_model_(.model)
  is_numeric_vector_(moderator)

  # only the first element of `type` is used, so the default vector selects
  # 'density'; anything outside the supported set is rejected up front
  type <- type[1]
  if (!(type %in% c('density', 'histogram', 'errorbar'))) {
    stop('type must be one of density, histogram or errorbar')
  }

  # adjust moderator to match estimand
  moderator <- adjust_for_estimand_(.model, moderator)
  estimand <- switch (.model$estimand,
                      ate = 'CATE',
                      att = 'CATT',
                      atc = 'CATC'
  )

  # extract the posterior
  posterior <- bartCause::extract(.model, 'icate')

  # get icate point est
  icate.m <- apply(posterior, 2, mean)

  # fit regression tree of the point estimates on the moderator
  tree <- rpart::rpart(icate.m ~ moderator)

  # get bins from regression tree: tree$where identifies the node each
  # observation falls in
  bins <- dplyr::tibble(splits = tree$where,
                        x = moderator)

  # label each bin with the observed min-max of the moderator inside it
  subgroups <- bins %>%
    dplyr::group_by(splits) %>%
    dplyr::summarise(min = min(x), max = max(x)) %>%
    dplyr::arrange(min) %>%
    dplyr::mutate(subgroup = paste0(.name,':', round(min, 2) ,'-', round(max, 2)))

  # explicit key avoids the "Joining with `by = ...`" message
  bins <- dplyr::left_join(bins, subgroups, by = 'splits')

  # rotate posterior and tag each row with its bin label
  posterior <- posterior %>%
    t() %>%
    as.data.frame() %>%
    as_tibble() %>%
    mutate(moderator = bins$subgroup)

  # marginalize: average within each bin column by column, then stack every
  # column except the bin label into name/value pairs
  posterior <- posterior %>%
    group_by(moderator) %>%
    summarise_all(mean) %>%
    pivot_longer(cols = -moderator)

  # plot it
  p <- ggplot(posterior, aes(value, fill = moderator))

  if(type == 'density'){
    p <- p + geom_density(alpha = .alpha) +
      labs(title = NULL,
           x = estimand,
           y = NULL) +
      theme(legend.position = 'bottom')
  }else if(type == 'histogram'){
    p <- p +
      geom_histogram(
        alpha = .alpha,
        col = 'black',
        position = 'identity') +
      labs(title = NULL,
           x = estimand,
           y = NULL) +
      theme(legend.position = 'bottom')
  } else{
    # one row per bin: mean plus the 2.5% and 97.5% quantiles of the binned
    # values. This is built from `posterior`; no object named `dat` exists in
    # this function before this point.
    dat <- posterior %>%
      group_by(moderator) %>%
      summarise(.min = quantile(value, .025, names = FALSE),
                .max = quantile(value, .975, names = FALSE),
                point = mean(value)) %>%
      arrange(desc(point))

    # plot it
    p <- ggplot(dat, aes(x = moderator, y = point, color = moderator)) +
      geom_point(size = 2) +
      geom_linerange(aes(ymin = .min, ymax = .max), alpha = .alpha) +
      labs(title = NULL,
           x = element_blank(),
           y = estimand) +
      theme(legend.position = 'bottom')
  }


  # add faceting
  if(isTRUE(facet)){
    p <- p + facet_wrap(~moderator, ncol = .ncol)
  }

  return(p)
}

#' @title Plot the Conditional Average Treatment Effect conditional on a discrete moderator
#' @description Plot the Conditional Average Treatment Effect split by a discrete moderating variable. This plot will provide a visual test of moderation by discrete variables.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param .alpha transparency value [0, 1]
#' @param facet TRUE/FALSE. Create panel plots of each moderator level?
#' @param .ncol number of columns to use when faceting
#' @param type string to specify if you would like to plot a histogram, density or error bar plot
#'
#' @author George Perrett
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom bartCause extract
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' plot_moderator_d(model_results, lalonde$educ)
#' }
plot_moderator_d <- function(.model, moderator, type = c('density', 'histogram', 'errorbar'), .alpha = 0.7, facet = FALSE, .ncol = 1){
  # only the first element of `type` is used, so the default selects 'density'
  type <- type[1]
  validate_model_(.model)
  is_discrete_(moderator)
  if(type %notin%c('density', 'histogram', 'errorbar')) stop('type must be either density, histogram errorbar')
  # adjust moderator to match estimand
  moderator <- adjust_for_estimand_(.model, moderator)

  # axis label that matches the estimand the model was fit for
  estimand <- switch (.model$estimand,
    ate = 'CATE',
    att = 'CATT',
    atc = 'CATC'
  )
  # extract and rotate posterior
  posterior <- bartCause::extract(.model, 'icate')

  posterior <- posterior %>%
    t() %>%
    as.data.frame() %>%
    as_tibble()

  # split posterior into list of dfs by each level of moderator, then average
  # each column within a level
  split_posterior <- split(posterior, moderator)
  posterior_means <- lapply(split_posterior, colMeans)

  # stack into a data.frame for plotting. The level label comes straight from
  # the list names; recovering it from the "<level>.<column>" names produced by
  # unlist() would truncate any level containing a dot (0.25 and 0.5 would
  # both become "0").
  dat <- data.frame(
    value = unlist(posterior_means, use.names = FALSE),
    moderator = rep(names(posterior_means), times = lengths(posterior_means))
  )

  # plot it
  p <- ggplot(dat, aes(value, fill = moderator))

  if(type == 'density'){
    p <- p + geom_density(alpha = .alpha) +
      labs(title = NULL,
           x = estimand,
           y = NULL) +
      theme(legend.position = 'bottom')
  }else if(type == 'histogram'){
    p <- p +
      geom_histogram(
        alpha = .alpha,
        col = 'black',
        position = 'identity') +
      labs(title = NULL,
           x = estimand,
           y = NULL) +
      theme(legend.position = 'bottom')
  } else{
    # tidy up the data: per level, the mean and the 2.5% / 97.5% quantiles;
    # dropping `value` and de-duplicating leaves one row per level
    dat <- dat %>%
      group_by(moderator) %>%
      mutate(.min = quantile(value, .025),
             .max = quantile(value, .975),
             point = mean(value)) %>%
      dplyr::select(-value) %>%
      arrange(desc(point)) %>%
      ungroup() %>%
      distinct()

    # plot it
    p <- ggplot(dat, aes(x = moderator, y = point, color = moderator)) +
      geom_point(size = 2) +
      geom_linerange(aes(ymin = .min, ymax = .max), alpha = .alpha) +
      labs(title = NULL,
           x = element_blank(),
           y = estimand) +
      theme(legend.position = 'bottom')
  }


  # add faceting
  if(isTRUE(facet)){
    p <- p + facet_wrap(~moderator, ncol = .ncol)
  }

  return(p)
}



#' @title Plot the posterior interval of the Conditional Average Treatment Effect grouped by a discrete variable
#' @description Plots the range of the Conditional Average Treatment Effect grouped by a discrete variable. This is analogous to plot_moderator_d_density but is preferable for moderators with many categories. Rather than plotting the full density, the posterior range is shown.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param .alpha transparency value [0, 1]
#' @param horizontal flip the plot horizontal?
#'
#' @author George Perrett, Joseph Marlo
#'
#' @return ggplot object
#' @export
#'
#' @import ggplot2 dplyr
#' @importFrom bartCause extract
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' plot_moderator_d_linerange(model_results, lalonde$educ)
#' }
plot_moderator_d_linerange <- function(.model, moderator, .alpha = 0.7, horizontal = FALSE){

  validate_model_(.model)
  is_discrete_(moderator)

  # adjust moderator to match estimand
  moderator <- adjust_for_estimand_(.model, moderator)

  # axis label that matches the estimand the model was fit for, consistent
  # with plot_moderator_d()
  estimand <- switch (.model$estimand,
    ate = 'CATE',
    att = 'CATT',
    atc = 'CATC'
  )

  # extract and rotate posterior
  posterior <- bartCause::extract(.model, 'icate')
  posterior <- posterior  %>%
    t() %>%
    as.data.frame() %>%
    as_tibble()

  # split posterior into list of dfs by each level of moderator, then average
  # each column within a level
  split_posterior <- split(posterior, moderator)
  posterior_means <- lapply(split_posterior, colMeans)

  # stack into a data.frame for plotting. The level label comes straight from
  # the list names; recovering it from the "<level>.<column>" names produced by
  # unlist() would truncate any level containing a dot (0.25 and 0.5 would
  # both become "0").
  dat <- data.frame(
    value = unlist(posterior_means, use.names = FALSE),
    moderator = rep(names(posterior_means), times = lengths(posterior_means))
  )

  # tidy up the data: per level, the mean and the 2.5% / 97.5% quantiles;
  # dropping `value` and de-duplicating leaves one row per level
  dat <- dat %>%
    group_by(moderator) %>%
    mutate(.min = quantile(value, .025),
           .max = quantile(value, .975),
           point = mean(value)) %>%
    dplyr::select(-value) %>%
    arrange(desc(point)) %>%
    ungroup() %>%
    distinct()

  # plot it
  p <- ggplot(dat, aes(x = moderator, y = point, color = moderator)) +
    geom_point(size = 2) +
    geom_linerange(aes(ymin = .min, ymax = .max), alpha = .alpha) +
    labs(title = NULL,
         x = element_blank(),
         y = estimand) +
    theme(legend.position = 'bottom')

  if (horizontal) p <- p + coord_flip()

  return(p)
}

#' @title Plot a single regression tree of covariates on ICATEs
#' @description Plot a single regression tree for exploratory heterogeneous effects. Fit single regression tree on bartc() ICATEs to produce variable importance plot. This plot is useful for identifying potential moderating variables.
#' Tree depth may be set to depths 1, 2 or 3. Terminal nodes signal the Conditional Average Treatment effect within levels of moderation variables. Trees with different values across terminal nodes suggest strong treatment effect moderation.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param max_depth one of c(1, 2, 3). Maximum number of node levels within the tree. 2 is recommended
#'
#' @author George Perrett, Joseph Marlo
#'
#' @import ggplot2 dplyr
#' @importFrom ggdendro dendro_data
#' @importFrom rpart rpart
#' @importFrom bartCause extract
#'
#' @return ggplot object
#' @export
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' plot_moderator_search(model_results)
#' }
plot_moderator_search <- function(.model, max_depth = c(2, 1, 3)){

  validate_model_(.model)

  # only the first element is used, so the default vector selects depth 2
  max_depth <- max_depth[[1]]
  if (max_depth %notin% c(2, 1, 3)) stop('max_depth must be one of c(1, 2, 3)')

  # column-wise mean of the extracted ICATE draws; the name `icate.m` is
  # referenced by the rpart formula below
  icate_draws <- bartCause::extract(.model , 'icate')
  icate.m <- apply(icate_draws, 2, mean)

  # response-model design matrix as a data.frame
  .data <- as.data.frame(.model$data.rsp@x)

  # for any estimand other than 'ate', keep only the rows of the relevant arm:
  # treated rows for 'att', control rows otherwise
  if (.model$estimand != 'ate') {
    keep_arm <- if (.model$estimand == 'att') 1 else 0
    .data <- .data[.model$trt == keep_arm,]
  }

  # drop the first and last columns before fitting
  # NOTE(review): presumably these are non-confounder columns of data.rsp@x --
  # confirm against the bartCause column layout
  non_confounder_cols <- c(-1,-(length(.data)))
  confounders <- as.matrix(.data[, non_confounder_cols])

  # single regression tree of the ICATE means on the remaining columns
  cart <- rpart::rpart(icate.m ~ ., data = as.data.frame(confounders), maxdepth = max_depth)

  # render the fitted tree as a ggplot dendrogram
  return(rpart_ggplot_(cart))
}


# Internal helper: draw a fitted tree as a ggplot dendrogram.
#
# `.model` is the tree object to draw; within this file it is called from
# plot_moderator_search() with the result of rpart::rpart(). The function reads
# `.model$frame` (columns `var`, `n`, `yval`, `dev`) and the `segments`,
# `labels` and `leaf_labels` components returned by ggdendro::dendro_data().
# Returns a ggplot object.
rpart_ggplot_ <- function(.model){

  # remove depth information from model so resulting plot is easy to read
  # NOTE(review): presumably dendro_data() derives branch heights from `dev`,
  # so a constant value gives uniform spacing -- confirm
  .model$frame$dev <- 1

  # extract data to construct dendrogram
  fitr <- ggdendro::dendro_data(.model)
  # observation counts for terminal nodes and for split nodes, taken from the
  # rows of the frame whose `var` is / is not '<leaf>'
  n_leaf <- .model$frame$n[.model$frame$var == '<leaf>']
  n_split <- .model$frame$n[.model$frame$var != '<leaf>']
  # fitted value at each split node, rounded to one decimal
  pred_split <- round(.model$frame$yval[.model$frame$var != '<leaf>'], 1)
  # fixed y position at which all leaf labels are drawn
  terminal_leaf_y <- 0.1
  # leaf annotation: dendro_data's leaf label plus the node size
  leaf_labels <- tibble(
    x = fitr$leaf_labels$x,
    y = terminal_leaf_y,
    label = paste0(
      'y = ', fitr$leaf_labels$label,
      '\nn = ', n_leaf)
  )
  # "yes"/"no" markers placed at 0.75x and 1.25x of the first split label's x,
  # at that label's y
  yes_no_offset <- c(0.75, 1.25)
  yes_no <- tibble(
    x = c(fitr$labels$x[[1]] * yes_no_offset[1],
          fitr$labels$x[[1]] * yes_no_offset[2]),
    y = rep(fitr$labels$y[[1]], 2),
    label = c("yes", "no")
  )
  # split annotation: fitted value and node size, nudged 0.07 above the
  # split-rule label
  split_labels <- tibble(
    x = fitr$labels$x,
    y = fitr$labels$y + 0.07,
    label = paste0(
      'y = ', pred_split,
      '\nn = ', n_split
    )
  )

  # set terminal segments to y = terminal_leaf_y
  # Step 1 orients every segment so `y` holds the larger and `yend` the smaller
  # of its two endpoints. Step 2 runs sequentially: `y` is rewritten first, and
  # the `yend` expression then sees the updated `y`.
  initial_node_y <- fitr$labels$y[[1]]
  fitr$segments <- fitr$segments %>%
    mutate(y_new = ifelse(y > yend, y, yend),
           yend_new = ifelse(yend < y, yend, y)) %>%
    select(n, x, y = y_new, xend, yend = yend_new) %>%
    mutate(y = ifelse(y > initial_node_y, terminal_leaf_y, y),
           yend = ifelse(x == xend & x == round(x) & y > yend, terminal_leaf_y, yend))

  # set plot constants
  label_text_size <- 3
  # pad the axes so the labels at the edges are not clipped
  x_limits <- c(0.5, nrow(fitr$leaf_labels) + 0.5)
  y_limits <- c(min(fitr$segments$y) - 0.05,
                max(fitr$segments$y) + 0.15)

  # plot it
  # layers are added in this order: segments, yes/no markers, leaf labels,
  # split summaries, then the split-rule labels (bold, label.size = NA)
  p <- ggplot() +
    geom_segment(data = fitr$segments,
                 aes(x = x, y = y, xend = xend, yend = yend)) +
    geom_label(data = yes_no,
               aes(x = x, y = y, label = label),
               size = label_text_size) +
    geom_label(data = leaf_labels,
               aes(x = x, y = y, label = label),
               size = label_text_size) +
    geom_label(data = split_labels,
               aes(x = x, y = y, label = label),
               size = label_text_size) +
    geom_label(data = fitr$labels,
               aes(x = x, y = y, label = label),
               label.size = NA, fontface = 'bold') +
    expand_limits(x = x_limits,
                  y = y_limits) +
    scale_x_continuous(labels = NULL, breaks = NULL) +
    scale_y_continuous(labels = NULL, breaks = NULL) +
    labs(title = 'Exploratory heterogeneous effects',
         x = NULL,
         y = NULL) +
    theme(panel.background = element_blank())

  return(p)
}



#' @title Auto-Bin a table of a continuous moderating variable into a discrete moderating variable
#' @description Use a regression tree to optimally bin a continuous variable. This function returns a table with the estimate, standard deviation and 95% interval bounds for each bin.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param .name string representing the name of the moderating variable
#'
#' @author George Perrett
#'
#'
#' @return a data.frame object
#' @export
#'
#' @import dplyr
#' @importFrom bartCause extract
#' @importFrom rpart rpart
#' @import stats
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' table_moderator_c_bin(model_results, lalonde$age, .name = 'age')
#' }
table_moderator_c_bin <- function(.model, moderator, .name = 'bin'){

  validate_model_(.model)
  is_numeric_vector_(moderator)

  # adjust moderator to match estimand
  moderator <- adjust_for_estimand_(.model, moderator)

  # extract the posterior
  posterior <- bartCause::extract(.model, 'icate')

  # get icate point est
  icate.m <- apply(posterior, 2, mean)

  # fit regression tree of the point estimates on the moderator
  tree <- rpart::rpart(icate.m ~ moderator)

  # get bins from regression tree: tree$where identifies the node each
  # observation falls in
  bins <- dplyr::tibble(splits = tree$where,
                        x = moderator)

  # label each bin with the observed min-max of the moderator inside it
  subgroups <- bins %>%
    dplyr::group_by(splits) %>%
    dplyr::summarise(min = min(x), max = max(x)) %>%
    dplyr::arrange(min) %>%
    dplyr::mutate(subgroup = paste0(.name,':', round(min, 2) ,'-', round(max, 2)))

  # explicit key avoids the "Joining with `by = ...`" message
  bins <- dplyr::left_join(bins, subgroups, by = 'splits')

  # rotate posterior and tag each row with its bin label
  posterior <- posterior %>%
    t() %>%
    as.data.frame() %>%
    as_tibble() %>%
    mutate(moderator = bins$subgroup)

  # marginalize: average within each bin column by column, then stack every
  # column except the bin label into name/value pairs
  posterior <- posterior %>%
    group_by(moderator) %>%
    summarise_all(mean) %>%
    pivot_longer(cols = -moderator)

  # one row per bin: mean, standard deviation and the 2.5% / 97.5% quantiles of
  # the binned values. summarise() returns an ungrouped tibble with columns
  # moderator, est, sd, lci, uci.
  posterior %>%
    dplyr::group_by(moderator) %>%
    dplyr::summarise(
      est = mean(value),
      sd = stats::sd(value),
      lci = stats::quantile(value, probs = .025, names = FALSE),
      uci = stats::quantile(value, probs = .975, names = FALSE))

}
priism-center/plotBart documentation built on June 2, 2024, 8:50 a.m.