#' Estimate a species-pathway abundance random effects model
#' @description Fit a model of the form log10_pwy_abd ~ log10_species_abd + (1|pwy) + (0+group|pwy) for a single
#'   bug
#' @details The priors are as follows: A) student_t(3, mean(log10_pwy_abd), 2.5) on the intercept at the
#'   mean species abundance. B) half student_t(3, 0, 2.5) on the residual noise C) 0-centered normal
#'   priors on pathway specific effects C1) half student_t(5,0,2.5) on the SD of pathway intercepts
#'   C2) half standard normal on the SD of group-specific effects.
#'
#'   The pathway index is generated by converting the pwy column to a factor and then to the
#'   corresponding integer index. Missing values in the pwy column are not allowed.
#'
#'   The group_ind column should be numeric with values in {0,1}
#'
#'   The main parameters of interest are the elements of the \code{pwy_effects} parameter. The "hit"
#'   column is defined by selecting the bug:pwy combinations where 1) 98\% posterior intervals for
#'   the pwy:group effect exclude 0, 2) the absolute posterior mean exceeds the specified effect
#'   size threshold, and 3) the estimated fixed effect of log10_species_abd on log10_pwy_abd is positive.
#'   The "hit" column is set to NA for every summary row that is not a \code{pwy_effects} element.
#' @param bug_pwy_dat a data frame with a row for each observation and columns "pwy",
#'   "log10_species_abd", "log10_pwy_abd", and a group indicator column named according to the
#'   \code{group_ind} argument
#' @param group_ind a character giving the name of the column for the 0/1 group indicator variable
#'   in \code{bug_pwy_dat}
#' @param effect_size_threshold effect size threshold for hit-calling pathway:group effects
#' @param group_exp_rate rate parameter of the exponential distribution prior on group effects
#' @param ... other arguments to pass to cmdstanr::sample()
#' @returns a one-row tibble with two list columns: \code{model_fit}, holding the CmdStanMCMC object
#'   of the model fit, and \code{summary_df}, holding the summary data frame.
#' @seealso \code{\link[=plot_pwy_ranef]{plot_pwy_ranef()}} , \code{\link[=plot_pwy_ranef_intervals]{plot_pwy_ranef_intervals()}}, \code{\link[cmdstanr:CmdStanMCMC]{cmdstanr::CmdStanMCMC}}
#' @export
anpan_pwy_ranef = function(bug_pwy_dat,
                           group_ind = "crc",
                           effect_size_threshold = .25,
                           group_exp_rate = 3,
                           ...) {

  required_cols = c("pwy", "log10_species_abd", "log10_pwy_abd", group_ind)
  if (!all(required_cols %in% names(bug_pwy_dat))) {
    stop("The necessary variables are not present in bug_pwy_dat. See ?anpan_pwy_ranef for what the columns should be called.")
  }

  if (nrow(bug_pwy_dat) == 0) {
    stop("bug_pwy_dat has no rows.")
  }

  if (!is.factor(bug_pwy_dat$pwy)) {
    warning("Converting the pwy column to a factor.")
  }

  # Build the factor once so that the pathway count, the integer index and the labels can never
  # disagree. factor() drops unused levels and missing values.
  pwy_fct = factor(bug_pwy_dat$pwy)

  # A missing pathway would produce an NA index in the Stan data; fail early with a clear message.
  if (anyNA(pwy_fct)) {
    stop("The pwy column of bug_pwy_dat contains missing values.")
  }

  n_pwy = nlevels(pwy_fct)

  model_path = system.file("stan", "pwy_ranef.stan",
                           package = 'anpan',
                           mustWork = TRUE)

  pwy_ranef_model = cmdstanr::cmdstan_model(stan_file = model_path,
                                            quiet = TRUE)

  # NOTE(review): the roxygen details describe a half standard normal prior on the SD of the group
  # effects, while group_exp_rate is documented as an exponential rate -- confirm against
  # pwy_ranef.stan which of the two is current.
  data_list = list(N = nrow(bug_pwy_dat),
                   pwy_abd = bug_pwy_dat$log10_pwy_abd,
                   pwy_mean = mean(bug_pwy_dat$log10_pwy_abd),
                   intercept_species = stats::model.matrix(~log10_species_abd, data = bug_pwy_dat),
                   N_pwy = n_pwy,
                   pwy_ind = as.integer(pwy_fct),
                   group_ind = bug_pwy_dat[[group_ind]],
                   group_exp_rate = group_exp_rate)

  # Lookup from Stan variable name to pathway label: two rows per pathway, one for its group
  # effect and one for its intercept.
  pwy_index = rep(seq_len(n_pwy), each = 2)
  pwy_ind_map = tibble::tibble(index = pwy_index,
                               pwy = rep(levels(pwy_fct), each = 2),
                               variable = paste0(rep(c("pwy_effects[", "pwy_intercepts["),
                                                     times = n_pwy),
                                                 pwy_index,
                                                 "]"))

  mod_fit = pwy_ranef_model$sample(data = data_list, ...)

  summary_df = mod_fit$draws() |>
    posterior::summarise_draws(posterior::default_summary_measures(),
                               wide = ~purrr::set_names(stats::quantile(.x, probs = c(.01, .99)), c("q1", "q99")),
                               posterior::default_convergence_measures()) |>
    dplyr::filter(grepl("^pwy|global|sigma|sd_|species_beta", variable)) |> # Discard variables most users won't be interested in like lp, lprior
    dplyr::left_join(pwy_ind_map, by = 'variable') |>
    # hit = 98% interval excludes 0, AND large enough effect, AND positive species slope.
    # The species_beta subset is recycled across rows by mutate().
    dplyr::mutate(hit = (!(q1 < 0 & q99 > 0)) &
                    (abs(mean) > effect_size_threshold) &
                    mean[grepl("species_beta", variable)] > 0) |>
    dplyr::select(pwy, hit, variable:ess_tail)

  # Hit calls are only meaningful for the pwy:group effects.
  summary_df$hit[!grepl("^pwy_eff", summary_df$variable)] = NA

  return(tibble::tibble(model_fit = list(mod_fit),
                        summary_df = list(summary_df)))
}
# Error-capturing wrapper around anpan_pwy_ranef(). purrr::safely() returns a function that, instead
# of throwing, returns a list with elements `result` and `error` (one of which is NULL).
safely_anpan_pwy_ranef = purrr::safely(anpan_pwy_ranef)
#' Fit the pathway random effects model for multiple bugs
#' @details In addition to the column requirements of \code{anpan_pwy_ranef()}, the input data frame
#'   \code{bug_pwy_dat} here must also contain a variable called "bug" which gives a unique
#'   identifier for each bug.
#'
#'   Arguments in \code{...} are forwarded to \code{anpan_pwy_ranef()} for each bug, so
#'   \code{effect_size_threshold} can also be supplied there.
#'
#'   Bugs whose model fails to fit are dropped from the returned tibble with a warning. If
#'   \code{out_dir} is provided, the captured errors are saved to errors.RData and the results to
#'   pwy_ranef_batch_res.RData in that directory.
#' @param out_dir output directory
#' @returns a tibble of row-binded anpan_pwy_ranef results
#' @inheritParams anpan_pwy_ranef
#' @examples \dontrun{
#' library(tidyverse)
#' library(anpan)
#'
#' set.seed(123)
#' input_dat = tibble(bug = rep(paste0("bug", 1:5), each = 200),
#'                    pwy = rep(paste0('pwy', 1:5), times = 200),
#'                    log10_species_abd = rnorm(1000),
#'                    log10_pwy_abd = rnorm(1000, mean = .8*log10_species_abd),
#'                    group = sample(c(0,1), size = 1000, replace = TRUE))
#' # ^ the pathway and group are NOT related in any pathway.
#'
#' res = anpan_pwy_ranef_batch(input_dat, group_ind = "group")
#'
#' # Examine the summary
#' pwy_group_res = res |>
#'   dplyr::select(bug, summary_df) |> # select the two main columns
#'   tidyr::unnest(c(summary_df)) |> # unnest
#'   dplyr::filter(grepl("^pwy_eff", variable)) |> # get just the pwy:group terms
#'   dplyr::arrange(-abs(mean)) # sort by decreasing effect size
#'
#' print(pwy_group_res)
#'
#' pwy_group_res |> filter(hit)
#' # ^ Here, there are no hits because we simulated with no dependence
#' }
#'
#' @seealso \code{\link[=plot_pwy_ranef]{plot_pwy_ranef()}} , \code{\link[=plot_pwy_ranef_intervals]{plot_pwy_ranef_intervals()}}, \code{\link[cmdstanr:CmdStanMCMC]{cmdstanr::CmdStanMCMC}}
#' @export
anpan_pwy_ranef_batch = function(bug_pwy_dat,
                                 group_ind = "crc",
                                 out_dir = NULL,
                                 group_exp_rate = 3,
                                 ...) {

  required_cols = c("bug", "pwy", "log10_species_abd", "log10_pwy_abd", group_ind)
  if (!all(required_cols %in% names(bug_pwy_dat))) {
    stop("The necessary variables are not present in bug_pwy_dat. See ?anpan_pwy_ranef for what the columns should be called.")
  }

  if (nrow(bug_pwy_dat) == 0) {
    stop("bug_pwy_dat has no rows.")
  }

  # Create out_dir (including any missing parent directories) the first time it is needed.
  ensure_out_dir = function() {
    if (!dir.exists(out_dir)) {
      message("Creating output directory.")
      dir.create(out_dir, recursive = TRUE)
    }
  }

  model_path = system.file("stan", "pwy_ranef.stan",
                           package = 'anpan',
                           mustWork = TRUE)

  # Compile it once ahead of time. The object itself is not used below; the call is kept for its
  # compilation side effect.
  pwy_ranef_model = cmdstanr::cmdstan_model(stan_file = model_path,
                                            quiet = TRUE)

  if (!data.table::is.data.table(bug_pwy_dat)) {
    message("Converting input to data.table")
    bug_pwy_dat = data.table::as.data.table(bug_pwy_dat)
  }

  if (!is.factor(bug_pwy_dat$bug)) {
    message("Converting bug column to factor")
    bug_pwy_dat = bug_pwy_dat |>
      dplyr::mutate(bug = factor(bug))
  }

  # drop = TRUE: a factor `bug` column with unused levels would otherwise yield empty groups,
  # each of which would be fit and fail.
  dat_by_bug = split(x = bug_pwy_dat, by = "bug", drop = TRUE)

  p = progressr::progressor(steps = length(dat_by_bug))

  res = dat_by_bug |>
    furrr::future_imap(function(.x, .y) {
      bug_res = safely_anpan_pwy_ranef(bug_pwy_dat = .x,
                                       group_ind = group_ind,
                                       group_exp_rate = group_exp_rate,
                                       ...)
      p()

      # Label successful fits with the bug name (the name of the split element).
      if (is.null(bug_res$error)) {
        bug_res$result$bug = .y
        bug_res$result = dplyr::relocate(bug_res$result, bug)
      }

      return(bug_res)
    },
    .options = furrr::furrr_options(seed = 123,
                                    scheduling = Inf))

  # One row per bug with list columns `result` and `error`.
  res_df = purrr::transpose(res) |>
    tibble::as_tibble()

  errors = res_df |>
    dplyr::filter(purrr::map_lgl(result, is.null))

  if (nrow(errors) > 0) {
    warning("Some bugs failed to fit. See errors.RData in the output directory.")

    if (is.null(out_dir)) {
      warning("Output directory not provided. Errors not saved.")
    } else {
      ensure_out_dir()
      save(errors,
           file = file.path(out_dir, "errors.RData"))
    }
  }

  pwy_ranef_batch_res = res_df |>
    dplyr::filter(purrr::map_lgl(error, is.null)) |>
    dplyr::pull(result) |>
    data.table::rbindlist() |>
    tibble::as_tibble()

  if (!is.null(out_dir)) {
    message("Saving results to pwy_ranef_batch_res.RData in the specified output directory.")
    ensure_out_dir()
    save(pwy_ranef_batch_res,
         file = file.path(out_dir, "pwy_ranef_batch_res.RData"))
  }

  return(pwy_ranef_batch_res)
}
#' Plot posterior intervals of the pathway:group effects
#' @description Draws the posterior mean and the q1-q99 interval of every \code{pwy_effects} term,
#'   colored by hit status. Batch results (those with a "bug" column) are faceted by bug.
#' @inheritParams anpan_pwy_ranef
#' @param pwy_ranef_res a result from \code{\link[=anpan_pwy_ranef]{anpan_pwy_ranef()}} or
#'   \code{\link[=anpan_pwy_ranef_batch]{anpan_pwy_ranef_batch()}}
#' @param ncol The number of columns to use if faceting multiple bugs.
#' @returns a ggplot object
#' @export
plot_pwy_ranef_intervals = function(pwy_ranef_res,
                                    group_ind = 'crc',
                                    ncol = 1) {

  # Shared preparation: keep the pwy:group terms, order by decreasing effect size, label the hits,
  # and fix the pathway order to the order of first appearance.
  prep_effects = function(summ) {
    summ |>
      filter(grepl("^pwy_eff", variable)) |>
      arrange(-abs(mean)) |>
      mutate(hit_lab = factor(c("non-hit", "hit")[hit + 1],
                              levels = c("hit", "non-hit")),
             pwy = factor(pwy, levels = unique(pwy)))
  }

  is_batch = "bug" %in% names(pwy_ranef_res)

  if (is_batch) {
    # Unnest the per-bug summaries, keeping the bug label on each row.
    nested_dt = pwy_ranef_res |>
      select(bug, summary_df) |>
      as.data.table()

    all_summaries = nested_dt[,data.table::rbindlist(summary_df), by = bug] |>
      tibble::as_tibble()

    facets = facet_wrap("bug", scales = "free", ncol = ncol)
  } else {
    all_summaries = pwy_ranef_res |>
      dplyr::pull(summary_df) |>
      dplyr::bind_rows()

    facets = NULL
  }

  plot_input = prep_effects(all_summaries)

  hit_colors = c("hit" = "#E41A1C", # brewer set1 but reversed
                 "non-hit" = "#377EB8")

  interval_plot = ggplot(plot_input, aes(mean, pwy)) +
    geom_vline(xintercept = 0,
               lty = 2, color = 'grey50') +
    geom_segment(aes(x = q1, xend = q99,
                     yend = pwy)) +
    geom_point(aes(color = hit_lab))

  p = interval_plot +
    labs(color = NULL,
         y = NULL,
         x = "pwy:group estimate\n(98% posterior intervals)") +
    scale_color_manual(values = hit_colors) +
    facets +
    theme_light() +
    theme(strip.text = element_text(color = 'grey20'))

  return(p)
}
#' Plot pathway data alongside posterior draws of a pathway random effects result
#' @inheritParams anpan_pwy_ranef
#' @param pwy_ranef_res a result from \code{\link[=anpan_pwy_ranef]{anpan_pwy_ranef()}} or
#'   \code{\link[=anpan_pwy_ranef_batch]{anpan_pwy_ranef_batch()}}
#' @param max_pwy the maximum number of bug:pwy facets to include
#' @param bug_name name of the bug (if using a result from \code{anpan_pwy_ranef()})
#' @param post_draws number of post draws to draw in each facet
#' @param group_labels labels for the 0/1 indicator to use on the plots
#' @param verbose logical for verbosity
#' @details This function plots bug:pwy data alongside posterior draws. If there's a strong group
#'   effect in a particular bug:pwy combination, you will see wide separation of the red and blue
#'   posterior lines.
#'
#'   If \code{bug_name} is specified and \code{bug_pwy_dat} has a "bug" column containing it,
#'   \code{bug_pwy_dat} is first filtered to just data from that bug.
#'
#'   If specified, \code{bug_name} must exactly match the corresponding entries in
#'   \code{bug_pwy_dat}.
#' @returns a ggplot object
#' @export
plot_pwy_ranef = function(bug_pwy_dat,
                          pwy_ranef_res,
                          group_ind = 'crc',
                          group_labels = c("ctrl", "case"),
                          bug_name = NULL,
                          max_pwy = 20,
                          post_draws = 30,
                          verbose = TRUE) {

  res_has_bug = "bug" %in% colnames(pwy_ranef_res)
  dat_has_bug = "bug" %in% colnames(bug_pwy_dat)

  # Restrict the data to the requested bug whenever it can be identified in bug_pwy_dat. This used
  # to be skipped for single-bug results, which pooled every bug's data under one label.
  if (!is.null(bug_name) && dat_has_bug && bug_name %in% bug_pwy_dat$bug) {
    if (verbose) message("Filtering bug_pwy_dat to the specified bug.")
    bug_pwy_dat = bug_pwy_dat |>
      dplyr::filter(bug == !!bug_name)
  }

  if (!res_has_bug) {
    # Single-bug result: label both inputs with the given name, or a placeholder.
    if (is.null(bug_name)) {
      warning("Couldn't determine the bug name from inputs, setting a placeholder.")
      bug_label = "bug"
    } else {
      bug_label = bug_name
    }
    pwy_ranef_res$bug = bug_label
    bug_pwy_dat$bug = bug_label
  } else if (!dat_has_bug) {
    # The result carries bug labels but the data does not; they can only be matched by name.
    if (is.null(bug_name)) {
      stop("bug_pwy_dat has no \"bug\" column. Specify bug_name to match it to pwy_ranef_res.")
    }
    bug_pwy_dat$bug = bug_name
  }

  top_pwys = bug_pwy_dat |>
    dplyr::select(bug, pwy) |>
    unique()

  if (nrow(top_pwys) > max_pwy) {
    if (verbose) message(paste0("Choosing the first ", max_pwy, " bug:pwy combinations in bug_pwy_dat. Subset bug_pwy_dat before calling this function if you'd like to show different bug:pwy combinations."))
    top_pwys = top_pwys |>
      head(n = max_pwy)
  }

  plot_data = bug_pwy_dat |>
    dplyr::inner_join(top_pwys, by = c("bug", "pwy"))

  # Map the 0/1 indicator onto the two group labels.
  plot_data = plot_data |>
    dplyr::mutate(group_var = group_labels[plot_data[[group_ind]] + 1])

  # Sample n_draws posterior draws from a fit and return them in long format (one row per
  # draw:variable), keeping only the variables needed to draw the lines. The argument has no
  # default: the previous `post_draws = post_draws` default was self-referential and would have
  # errored if ever relied upon.
  get_post_draws = function(cmdstan_fit, n_draws) {
    cmdstan_fit$draws(format = 'data.frame') |>
      tibble::as_tibble() |>
      dplyr::slice_sample(n = n_draws) |>
      data.table::as.data.table() |>
      data.table::melt(id.vars = c(".chain", ".iteration", ".draw"),
                       variable.name = 'variable',
                       value.name = 'value') |>
      tibble::as_tibble() |>
      dplyr::filter(grepl("^pwy_eff|beta|glob|^pwy_interc", variable))
  }

  # For a single posterior draw, compute each pathway's slope and its ctrl / case intercepts.
  line_from_iter = function(iter_df) {
    iter_df$slope = iter_df$value[iter_df$variable == "species_beta[1]"]
    iter_df$glob_int = iter_df$value[iter_df$variable == 'global_intercept']

    # pwy is the grouping key, so summarise() already carries it into the result.
    iter_df |>
      dplyr::filter(!is.na(pwy)) |>
      dplyr::group_by(pwy) |>
      dplyr::summarise(slope = slope[1],
                       ctrl = glob_int[1] + value[effects == "pwy_int"],
                       case = glob_int[1] + value[effects == "pwy_int"] + value[effects == "pwy_eff"])
  }

  # Attach pathway labels from the summary to the sampled draws, then turn every draw into one
  # line per pathway per group.
  combine_summ_with_draws = function(summary_df, rdraws) {
    summary_df |>
      dplyr::select(pwy, variable) |>
      dplyr::filter(!grepl("sd_|sigma", variable)) |>
      dplyr::full_join(rdraws, by = 'variable') |>
      dplyr::mutate(effects = stringr::str_extract(variable, "pwy_eff|pwy_int")) |>
      dplyr::group_split(`.draw`) |>
      purrr::map_dfr(line_from_iter) |>
      data.table::as.data.table() |>
      data.table::melt(measure.vars = c("case", "ctrl"),
                       variable.name = "group_var",
                       value.name = "int")
  }

  wrap_char = 40 # Expose to user?

  line_df = pwy_ranef_res |>
    dplyr::filter(bug %in% unique(plot_data$bug)) |>
    dplyr::mutate(rdraws = lapply(model_fit,
                                  get_post_draws,
                                  n_draws = post_draws),
                  line_draws = purrr::map2(summary_df, rdraws,
                                           combine_summ_with_draws)) |>
    dplyr::select(bug, line_draws) |>
    data.table::as.data.table()

  draw_df = line_df[,data.table::rbindlist(line_draws), by = bug] |>
    tibble::as_tibble() |>
    dplyr::inner_join(top_pwys, by = c("bug", "pwy")) |>
    dplyr::slice_sample(prop = 1) |> # shuffle the row order of the lines
    dplyr::mutate(pwy = stringr::str_wrap(pwy, width = wrap_char))

  # Translate the internal "ctrl"/"case" names into the user-facing group labels.
  replace_vector = group_labels
  names(replace_vector) = c("ctrl", "case")

  draw_df$group_var = factor(replace_vector[as.character(draw_df$group_var)],
                             levels = group_labels)

  color_vec = c("#1F78C8", "#ff0000")
  names(color_vec) = group_labels

  plot_data |>
    dplyr::mutate(pwy = stringr::str_wrap(pwy, width = wrap_char),
                  group_var = factor(group_var,
                                     levels = group_labels)) |>
    ggplot(aes(log10_species_abd, log10_pwy_abd)) +
    geom_abline(data = draw_df,
                aes(slope = slope,
                    intercept = int,
                    color = group_var),
                alpha = .33) +
    geom_point(aes(color = group_var),
               size = 1) +
    facet_wrap(c("bug", "pwy"), scales = 'free_y') +
    scale_color_manual(values = color_vec) + # pals::cols25(2) |> dput()
    theme_light() +
    theme(strip.text = element_text(color = 'grey20', margin = margin(.5, .5, .9, .5, unit = 'pt'),
                                    size = 6.5)) +
    labs(color = NULL)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.