R/pwy_ranef.R

Defines functions plot_pwy_ranef plot_pwy_ranef_intervals anpan_pwy_ranef_batch anpan_pwy_ranef

Documented in anpan_pwy_ranef anpan_pwy_ranef_batch plot_pwy_ranef plot_pwy_ranef_intervals

#' Estimate a species-pathway abundance random effects model
#' @description Fit a model of the form log10_pwy_abd ~ log10_species_abd + (1|pwy) + (0+group|pwy) for a single
#'   bug
#' @details The priors are as follows: A) student_t(3, mean(log10_pwy_abd), 2.5) on the intercept at the
#'   mean species abundance. B) half student_t(3, 0, 2.5) on the residual noise C) 0-centered normal
#'   priors on pathway specific effects C1) half student_t(5,0,2.5) on the SD of pathway intercepts
#'   C2) half standard normal on the SD of group-specific effects.
#'
#'   The pathway index is generated by converting the pwy column to a factor and then to the
#'   corresponding integer index.
#'
#'   The group_ind column should be numeric with values in {0,1}
#'
#'   The main parameters of interest are the elements of the \code{pwy_effects} parameter. The "hit"
#'   column is defined by selecting the bug:pwy combinations where 1) 98\% posterior intervals for
#'   the pwy:group effect exclude 0, 2) the absolute posterior mean exceeds the specified effect
#'   size threshold, and 3) the estimated fixed effect of log10_species_abd on log10_pwy_abd is positive.
#' @param bug_pwy_dat a data frame with a row for each observation and columns "pwy",
#'   "log10_species_abd", "log10_pwy_abd", and a group indicator column named according to the
#'   \code{group_ind} argument
#' @param group_ind a character giving the name of the column for the 0/1 group indicator variable
#'   in \code{bug_pwy_dat}
#' @param group_exp_rate rate parameter of the exponential distribution prior on group effects
#' @param effect_size_threshold effect size threshold for hit-calling pathway:group effects
#' @param ... other arguments to pass to cmdstanr::sample()
#' @returns a list containing the CmdStanMCMC object of the model fit and a summary data frame.
#' @seealso \code{\link[=plot_pwy_ranef]{plot_pwy_ranef()}} , \code{\link[=plot_pwy_ranef_intervals]{plot_pwy_ranef_intervals()}}, \code{\link[cmdstanr:CmdStanMCMC]{cmdstanr::CmdStanMCMC}}
#' @export
anpan_pwy_ranef = function(bug_pwy_dat,
                           group_ind = "crc",
                           effect_size_threshold = .25,
                           group_exp_rate = 3,
                           ...) {

  # Fit the pathway random effects Stan model for a single bug and summarise
  # the posterior. Returns a one-row tibble with list-columns `model_fit` (the
  # object returned by the model's $sample() method) and `summary_df`.

  required_cols = c("pwy", "log10_species_abd", "log10_pwy_abd", group_ind)

  if (!all(required_cols %in% names(bug_pwy_dat))) {
    stop("The necessary variables are not present in bug_pwy_dat. See ?anpan_pwy_ranef for what the columns should be called.")
  }

  # factor() silently drops NA, so a missing pathway label would produce NA
  # entries in the pathway index and a pathway count that disagrees with the
  # labels. Fail early with a clear message instead.
  if (anyNA(bug_pwy_dat$pwy)) {
    stop("The pwy column of bug_pwy_dat contains missing values. Remove or relabel them before fitting.")
  }

  if (!is.factor(bug_pwy_dat$pwy)) {
    warning("Converting the pwy column to a factor.")
  }

  # Build the factor once so the pathway count, the integer index and the
  # index -> label map are guaranteed to agree with each other.
  pwy_fct = factor(bug_pwy_dat$pwy)
  n_pwy = nlevels(pwy_fct)

  model_path = system.file("stan", "pwy_ranef.stan",
                           package = 'anpan',
                           mustWork = TRUE)

  pwy_ranef_model = cmdstanr::cmdstan_model(stan_file = model_path,
                                            quiet = TRUE)

  data_list = list(N = nrow(bug_pwy_dat),
                   pwy_abd = bug_pwy_dat$log10_pwy_abd,
                   pwy_mean = mean(bug_pwy_dat$log10_pwy_abd),
                   intercept_species = model.matrix(~log10_species_abd, data = bug_pwy_dat),
                   N_pwy = n_pwy,
                   pwy_ind = as.integer(pwy_fct),
                   group_ind = bug_pwy_dat[[group_ind]],
                   group_exp_rate = group_exp_rate)

  # Long lookup table mapping Stan variable names ("pwy_effects[i]",
  # "pwy_intercepts[i]") back to the pathway label they belong to.
  pwy_ind_map = tibble(index = seq_len(n_pwy),
                       pwy_group_effect = paste0("pwy_effects[", index, "]"),
                       pwy_intercept = paste0("pwy_intercepts[", index, "]"),
                       pwy = levels(pwy_fct)) |>
    data.table::as.data.table() |>
    data.table::melt(id.vars = c("index", "pwy"),
                     variable.name = 'var_names',
                     value.name = 'variable') |>
    dplyr::select(-var_names) |>
    dplyr::arrange(index) |>
    data.table::as.data.table()

  mod_fit = pwy_ranef_model$sample(data = data_list, ...)

  # q1 / q99 bound the 98% posterior interval used for hit-calling below.
  summary_df = mod_fit$draws() |>
    posterior::summarise_draws(posterior::default_summary_measures(),
                               wide = ~purrr::set_names(quantile(.x, probs = c(.01, .99)), c("q1", "q99")),
                               posterior::default_convergence_measures()) |>
    filter(grepl("^pwy|global|sigma|sd_|species_beta", variable)) |>  # Discard variables most users won't be interested in like lp, lprior
    dplyr::left_join(pwy_ind_map, by = 'variable') |>
    mutate(hit = (!(q1 < 0 & q99 > 0)) &
             (abs(mean) > effect_size_threshold) &
             mean[grepl("species_beta", variable)] > 0) |>
    select(pwy, hit, variable:ess_tail)

  # The hit flag is only meaningful for the pwy:group effects; blank it elsewhere.
  summary_df$hit[!grepl("^pwy_eff", summary_df$variable)] = NA

  return(tibble(model_fit = list(mod_fit),
                summary_df = list(summary_df)))
}

# Error-capturing wrapper around anpan_pwy_ranef(): the wrapped call yields a
# list with `result` and `error` elements, which anpan_pwy_ranef_batch() uses
# to keep going when an individual bug fails to fit.
safely_anpan_pwy_ranef = purrr::safely(.f = anpan_pwy_ranef,
                                       otherwise = NULL,
                                       quiet = TRUE)

#' Fit the pathway random effects model for multiple bugs
#' @details In addition to the column requirements of \code{anpan_pwy_ranef()}, the input data frame
#'   \code{bug_pwy_dat} here must also contain a variable called "bug" which gives a unique
#'   identifier for each bug.
#' @param out_dir output directory
#' @returns a tibble of row-bound anpan_pwy_ranef results
#' @inheritParams anpan_pwy_ranef
#' @examples \dontrun{
#' library(tidyverse)
#' library(anpan)
#'
#' set.seed(123)
#' input_dat = tibble(bug = rep(paste0("bug", 1:5), each = 200),
#'                    pwy = rep(paste0('pwy', 1:5), times = 200),
#'                    log10_species_abd = rnorm(1000),
#'                    log10_pwy_abd = rnorm(1000, mean = .8*log10_species_abd),
#'                    group = sample(c(0,1), size = 1000, replace = TRUE))
#'  # ^ the pathway and group are NOT related in any pathway.
#'
#' res = anpan_pwy_ranef_batch(input_dat, group_ind = "group")
#'
#' # Examine the summary
#' pwy_group_res = res |>
#'   dplyr::select(bug, summary_df) |>             # select the two main columns
#'   tidyr::unnest(c(summary_df)) |>               # unnest
#'   dplyr::filter(grepl("^pwy_eff", variable)) |> # get just the pwy:group terms
#'   dplyr::arrange(-abs(mean))                    # sort by decreasing effect size
#'
#' print(pwy_group_res)
#'
#' pwy_group_res |> filter(hit)
#' # ^ Here, there are no hits because we simulated with no dependence
#' }
#'
#' @seealso \code{\link[=plot_pwy_ranef]{plot_pwy_ranef()}} , \code{\link[=plot_pwy_ranef_intervals]{plot_pwy_ranef_intervals()}}, \code{\link[cmdstanr:CmdStanMCMC]{cmdstanr::CmdStanMCMC}}
#' @export
anpan_pwy_ranef_batch = function(bug_pwy_dat,
                                 group_ind = "crc",
                                 out_dir = NULL,
                                 group_exp_rate = 3,
                                 ...) {

  # Fit anpan_pwy_ranef() separately for every bug in bug_pwy_dat, collect the
  # successful fits into one tibble, and (when out_dir is given) save both the
  # results and any captured errors as .RData files in out_dir.

  required_cols = c("bug", "pwy", "log10_species_abd", "log10_pwy_abd", group_ind)

  if (!all(required_cols %in% names(bug_pwy_dat))) {
    stop("The necessary variables are not present in bug_pwy_dat. See ?anpan_pwy_ranef for what the columns should be called.")
  }

  # Create out_dir on demand. recursive = TRUE so a nested path whose parent
  # does not exist yet is created too instead of failing.
  ensure_out_dir = function() {
    if (!dir.exists(out_dir)) {
      message("Creating output directory.")
      dir.create(out_dir, recursive = TRUE)
    }
  }

  model_path = system.file("stan", "pwy_ranef.stan",
                           package = 'anpan',
                           mustWork = TRUE)

  # Compile it once ahead of time. The object itself is not used below; the
  # per-bug calls build their own model object from the same Stan file.
  pwy_ranef_model = cmdstanr::cmdstan_model(stan_file = model_path,
                                            quiet = TRUE)

  p = progressr::progressor(steps = dplyr::n_distinct(bug_pwy_dat$bug))

  if (!data.table::is.data.table(bug_pwy_dat)) {
    message("Converting input to data.table")
    bug_pwy_dat = data.table::as.data.table(bug_pwy_dat)
  }

  if (!is.factor(bug_pwy_dat$bug)) {
    message("Converting bug column to factor")

    bug_pwy_dat = bug_pwy_dat |>
      mutate(bug = factor(bug))
  }

  # Fit one bug; errors are captured rather than thrown so that a single
  # failure does not abort the whole batch.
  fit_one_bug = function(.x, .y) {
    bug_res = safely_anpan_pwy_ranef(bug_pwy_dat = .x,
                                     group_ind = group_ind,
                                     group_exp_rate = group_exp_rate,
                                     ...)
    p()

    if (is.null(bug_res$error)) {
      bug_res$result$bug = .y
      bug_res$result = dplyr::relocate(bug_res$result, bug)
    }

    return(bug_res)
  }

  res = split(x = bug_pwy_dat, by = "bug") |>
    furrr::future_imap(fit_one_bug,
                       .options = furrr::furrr_options(seed = 123,
                                                       scheduling = Inf))

  # One row per bug with list-columns `result` and `error`.
  res_df = purrr::transpose(res) |>
    tibble::as_tibble()

  errors = res_df |>
    filter(map_lgl(result, is.null))

  if (nrow(errors) > 0) {
    if (is.null(out_dir)) {
      # Nothing is written in this case, so don't point the user at a file.
      warning("Some bugs failed to fit.")
      warning("Output directory not provided. Errors not saved.")
    } else {
      ensure_out_dir()
      save(errors,
           file = file.path(out_dir, "errors.RData"))
      warning("Some bugs failed to fit. See errors.RData in the output directory.")
    }
  }

  pwy_ranef_batch_res = res_df |>
    filter(map_lgl(error, is.null)) |>
    pull(result) |>
    data.table::rbindlist() |>
    tibble::as_tibble()

  if (!is.null(out_dir)) {
    message("Saving results to pwy_ranef_batch_res.RData in the specified output directory.")

    ensure_out_dir()

    save(pwy_ranef_batch_res,
         file = file.path(out_dir, "pwy_ranef_batch_res.RData"))
  }

  return(pwy_ranef_batch_res)
}

#' Plot posterior intervals of the pathway:group effects from a pathway random effects result
#' @inheritParams anpan_pwy_ranef
#' @param pwy_ranef_res a result from \code{\link[=anpan_pwy_ranef]{anpan_pwy_ranef()}} or
#'   \code{\link[=anpan_pwy_ranef_batch]{anpan_pwy_ranef_batch()}}
#' @param ncol The number of columns to use if faceting multiple bugs.
#' @export
plot_pwy_ranef_intervals = function(pwy_ranef_res,
                                    group_ind = 'crc',
                                    ncol = 1) {

  # Draw one point (posterior mean) and one segment (q1 to q99) per pwy:group
  # effect, coloured by hit status. Results carrying a "bug" column are
  # faceted by bug; otherwise a single panel is drawn.

  is_batch = "bug" %in% names(pwy_ranef_res)

  if (is_batch) {
    # Stack the per-bug summaries, keeping the bug identifier alongside.
    nested_summaries = as.data.table(select(pwy_ranef_res, bug, summary_df))

    all_summaries = tibble::as_tibble(
      nested_summaries[, data.table::rbindlist(summary_df), by = bug]
    )

    facets = facet_wrap("bug", scales = "free", ncol = ncol)
  } else {
    all_summaries = dplyr::bind_rows(dplyr::pull(pwy_ranef_res, summary_df))

    facets = NULL
  }

  # Keep only the pwy:group terms, largest absolute effect first.
  effect_rows = filter(all_summaries, grepl("^pwy_eff", variable))
  effect_rows = arrange(effect_rows, -abs(mean))

  # Fixing the factor levels in sorted order makes the y axis follow effect size.
  plot_input = mutate(effect_rows,
                      hit_lab = factor(c("non-hit", "hit")[hit + 1],
                                       levels = c("hit", "non-hit")),
                      pwy = factor(pwy, levels = unique(pwy)))

  interval_plot = ggplot(plot_input, aes(mean, pwy)) +
    geom_vline(xintercept = 0,
               lty = 2, color = 'grey50') +
    geom_segment(aes(x = q1, xend = q99,
                     yend = pwy)) +
    geom_point(aes(color = hit_lab)) +
    labs(color = NULL,
         y = NULL,
         x = "pwy:group estimate\n(98% posterior intervals)") +
    scale_color_manual(values = c("hit" = "#E41A1C", # brewer set1 but reversed
                                  "non-hit" = "#377EB8")) +
    facets +
    theme_light() +
    theme(strip.text = element_text(color = 'grey20'))

  return(interval_plot)
}

#' Plot a pathway random effects result
#' @inheritParams anpan_pwy_ranef
#' @param max_pwy the maximum number of bug:pwy facets to include
#' @param bug_name name of the bug (if using a result from \code{anpan_pwy_ranef()})
#' @param post_draws number of post draws to draw in each facet
#' @param group_labels labels for the 0/1 indicator to use on the plots
#' @param verbose logical for verbosity
#' @details This function plots bug:pwy data alongside posterior draws. If there's a strong group
#'   effect in a particular bug:pwy combination, you will see wide separation of the red and blue
#'   posterior lines.
#'
#'   If \code{bug_name} is specified, \code{bug_pwy_dat} is first filtered to just data from that
#'   bug.
#'
#'   If specified, \code{bug_name} must exactly match the corresponding entries in
#'   \code{bug_pwy_dat}.
#' @export
plot_pwy_ranef = function(bug_pwy_dat,
                          pwy_ranef_res,
                          group_ind = 'crc',
                          group_labels = c("ctrl", "case"),
                          bug_name = NULL,
                          max_pwy = 20,
                          post_draws = 30,
                          verbose = TRUE) {

  # Scatter the observed data per bug:pwy facet and overlay lines built from a
  # random subset of posterior draws, one line per group per draw.

  res_has_bug = "bug" %in% colnames(pwy_ranef_res)
  dat_has_bug = "bug" %in% colnames(bug_pwy_dat)

  # Reconcile the "bug" identifier between the data and the model results.
  if (is.null(bug_name)) {
    if (!res_has_bug) {
      warning("Couldn't determine the bug name from inputs, setting a placeholder.")
      bug_pwy_dat$bug = "bug"
      pwy_ranef_res$bug = "bug"
    } else if (!dat_has_bug) {
      # Without this the function fails later when selecting the bug column.
      stop("bug_pwy_dat has no \"bug\" column to match against pwy_ranef_res. Add one or specify bug_name.")
    }
  } else {
    if (dat_has_bug) {
      # Filter rather than overwrite, so data from other bugs is never
      # relabelled as belonging to bug_name.
      if (verbose) message("Filtering bug_pwy_dat to the specified bug.")

      bug_pwy_dat = bug_pwy_dat |>
        filter(bug == bug_name)

      if (nrow(bug_pwy_dat) == 0) {
        stop("No rows of bug_pwy_dat match the specified bug_name.")
      }
    } else {
      bug_pwy_dat$bug = bug_name
    }

    if (!res_has_bug) {
      pwy_ranef_res$bug = bug_name
    }
  }

  top_pwys = bug_pwy_dat |>
    select(bug, pwy) |>
    unique()

  if (nrow(top_pwys) > max_pwy) {
    if (verbose) message(paste0("Choosing the first ", max_pwy, " bug:pwy combinations in bug_pwy_dat. Subset bug_pwy_dat before calling this function if you'd like to show different bug:pwy combinations."))

    top_pwys = top_pwys |>
      head(n = max_pwy)
  }

  plot_data = bug_pwy_dat |>
    inner_join(top_pwys, by = c("bug", "pwy"))

  # Map the 0/1 indicator onto the user-facing labels (0 -> first, 1 -> second).
  plot_data = plot_data |>
    mutate(group_var = group_labels[plot_data[[group_ind]] + 1])

  # Sample n_draws posterior draws from a fit and return them in long format,
  # restricted to the variables needed to build lines. The argument is
  # deliberately not named post_draws with a default of post_draws: such a
  # self-referential default errors whenever the default is actually used.
  get_post_draws = function(cmdstan_fit, n_draws) {
    cmdstan_fit$draws(format = 'data.frame') |>
      tibble::as_tibble() |>
      dplyr::slice_sample(n = n_draws) |>
      data.table::as.data.table() |>
      data.table::melt(id.vars = c(".chain", ".iteration", ".draw"),
                       variable.name = 'variable',
                       value.name = 'value') |>
      tibble::as_tibble() |>
      dplyr::filter(grepl("^pwy_eff|beta|glob|^pwy_interc", variable))
  }

  # For a single draw, compute per pathway the shared slope and the intercept
  # for each group (ctrl: global + pathway intercept; case: plus the effect).
  line_from_iter = function(iter_df) {
    iter_df$slope = iter_df$value[iter_df$variable == "species_beta[1]"]
    iter_df$glob_int = iter_df$value[iter_df$variable == 'global_intercept']
    iter_df |>
      filter(!is.na(pwy)) |>
      group_by(pwy) |>
      summarise(pwy = pwy[1],
                slope = slope[1],
                ctrl = glob_int[1] + value[effects == "pwy_int"],
                case = glob_int[1] + value[effects == "pwy_int"] + value[effects == "pwy_eff"])
  }

  # Attach pathway labels to the sampled draws and turn every draw into one
  # line per pathway per group.
  combine_summ_with_draws = function(summary_df, rdraws) {
    summary_df |> select(pwy, variable) |>
      filter(!grepl("sd_|sigma", variable)) |>
      full_join(rdraws, by = 'variable') |>
      mutate(effects = stringr::str_extract(variable, "pwy_eff|pwy_int")) |>
      group_split(`.draw`) |>
      map_dfr(line_from_iter) |>
      data.table::as.data.table() |>
      data.table::melt(measure.vars = c("case", "ctrl"),
                       variable.name = "group_var",
                       value.name = "int")
  }

  wrap_char = 40 # Expose to user?

  line_df = pwy_ranef_res |>
    filter(bug %in% unique(plot_data$bug)) |>
    mutate(rdraws = lapply(model_fit,
                           get_post_draws,
                           n_draws = post_draws),
           line_draws = purrr::map2(summary_df, rdraws,
                             combine_summ_with_draws)) |>
    select(bug, line_draws) |>
    data.table::as.data.table()

  # Rows are shuffled (slice_sample(prop = 1)), presumably so that neither
  # group's lines are systematically drawn on top of the other's.
  draw_df = line_df[,data.table::rbindlist(line_draws), by = bug] |>
    tibble::as_tibble() |>
    inner_join(top_pwys, by = c("bug", "pwy")) |>
    dplyr::slice_sample(prop = 1) |>
    mutate(pwy = stringr::str_wrap(pwy, width = wrap_char))

  # Translate the internal "ctrl"/"case" names into the user-supplied labels.
  replace_vector = group_labels
  names(replace_vector) = c("ctrl", "case")

  draw_df$group_var = factor(replace_vector[as.character(draw_df$group_var)],
                             levels = group_labels)

  color_vec = c("#1F78C8", "#ff0000")
  names(color_vec) = group_labels

  plot_data |>
    mutate(pwy = stringr::str_wrap(pwy, width = wrap_char),
           group_var = factor(group_var,
                              levels = group_labels)) |>
    ggplot(aes(log10_species_abd, log10_pwy_abd)) +
    geom_abline(data = draw_df,
                aes(slope = slope,
                    intercept = int,
                    color = group_var),
                alpha = .33) +
    geom_point(aes(color = group_var),
               size = 1) +
    facet_wrap(c("bug", "pwy"), scales = 'free_y') +
    scale_color_manual(values = color_vec) +  # pals::cols25(2) |> dput()
    theme_light() +
    theme(strip.text = element_text(color = 'grey20', margin = margin(.5, .5, .9, .5, unit = 'pt'),
                                    size = 6.5)) +
    labs(color = NULL)
}
biobakery/anpan documentation built on July 26, 2024, 11:19 p.m.