# Define global variable
sccomp_stan_models_cache_dir = file.path(path.expand("~"), ".sccomp_models", packageVersion("sccomp"))
# Greater than
gt = function(a, b){ a > b }
# Smaller than
st = function(a, b){ a < b }
# Negation
not = function(is){ !is }
#' Add attribute to object
#'
#' @keywords internal
#' @noRd
#'
#'
#' @param var A tibble
#' @param attribute An object
#' @param name A character name of the attribute
#'
#' @return A tibble with an additional attribute
add_attr = function(var, attribute, name) {
attr(var, name) <- attribute
var
}
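# Illustrative example (not run; values are hypothetical):
# x = add_attr(tibble::tibble(a = 1), "my_value", "my_name")
# attr(x, "my_name")
# #> [1] "my_value"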
#' Formula parser
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom magrittr extract2
#' @importFrom stats terms
#'
#' @return A character vector of the fixed effect variable names
#'
#' @keywords internal
#' @noRd
parse_formula <- function(fm) {
stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
as.character(attr(terms(fm), "variables")) |>
str_subset("\\|", negate = TRUE) %>%
    # extract2(-1) does not work with the native pipe,
    # so drop the leading "list" element with magrittr's %>% and .[-1]
    .[-1]
}
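# Illustrative example (not run): random effect terms are dropped,
# fixed effect variable names are returned
# parse_formula(~ type + continuous + (1 | group))
# #> [1] "type"       "continuous"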
#' Formula parser for random effect formulae
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom stringr str_split
#' @importFrom stringr str_remove_all
#' @importFrom rlang set_names
#' @importFrom purrr map_dfr
#' @importFrom stringr str_trim
#'
#' @importFrom magrittr extract2
#'
#' @return A tibble with columns `formula` (a list of formulas) and `grouping` (a character vector)
#'
#' @keywords internal
#' @noRd
formula_to_random_effect_formulae <- function(fm) {
# Define the variables as NULL to avoid CRAN NOTES
formula <- NULL
stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
random_effect_elements =
as.character(attr(terms(fm), "variables")) |>
# Select random intercept part
str_subset("\\|")
if(length(random_effect_elements) > 0){
random_effect_elements |>
# Divide grouping from factors
str_split("\\|") |>
# Set name
map_dfr(~ .x |> set_names(c("formula", "grouping"))) |>
# Create formula
mutate(formula = map(formula, ~ formula(glue("~ {.x}")))) |>
mutate(grouping = grouping |> str_trim())
}
else
tibble(`formula` = list(), grouping = character())
}
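# Illustrative example (not run): each (terms | group) element becomes one row
# formula_to_random_effect_formulae(~ type + (type | group))
# #> # A tibble: 1 x 2
# #>   formula   grouping
# #>   <list>    <chr>
# #> 1 <formula> group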
#' Formula parser for random effect factors
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom stringr str_split
#' @importFrom stringr str_remove_all
#' @importFrom rlang set_names
#' @importFrom purrr map_dfr
#'
#' @importFrom magrittr extract2
#'
#' @return A tibble with one row per random effect factor, with columns `factor` and `grouping`
#'
#' @keywords internal
#' @noRd
parse_formula_random_effect <- function(fm) {
# Define the variables as NULL to avoid CRAN NOTES
formula <- NULL
stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
random_effect_elements =
as.character(attr(terms(fm), "variables")) |>
# Select random intercept part
str_subset("\\|")
if(length(random_effect_elements) > 0){
formula_to_random_effect_formulae(fm) |>
# Divide factors
mutate(factor = map(
formula,
~
# Attach intercept
.x |>
terms() |>
attr("intercept") |>
str_replace("^1$", "(Intercept)") |>
str_subset("0", negate = TRUE) |>
# Attach variables
c(
.x |>
terms() |>
attr("variables") |>
as.character() |>
str_split("\\+") |>
as.character() %>%
.[-1]
)
)) |>
unnest(factor)
}
else
tibble(factor = character(), grouping = character())
}
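# Illustrative example (not run): the intercept and each slope of the random
# effect part are returned as separate rows
# parse_formula_random_effect(~ type + (type | group))
# #> 2 rows: factor = "(Intercept)", "type"; grouping = "group" for both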
#' Conditionally apply a function in a pipe
#'
#' @keywords internal
#' @noRd
#'
#' @import dplyr
#' @importFrom purrr as_mapper
#'
#' @param .x A tibble
#' @param .p A logical. Whether to apply `.f1` (TRUE) or fall back to `.f2` (FALSE)
#' @param .f1 A function
#' @param .f2 A function, applied when `.p` is FALSE (optional)
#'
#' @return The result of the applied function, or `.x` unchanged
ifelse_pipe = function(.x, .p, .f1, .f2 = NULL) {
  # Switch on index 1 (when .p is TRUE) or 2 (when .p is FALSE)
  switch(.p %>% `!` %>% sum(1),
as_mapper(.f1)(.x),
if (.f2 %>% is.null %>% `!`)
as_mapper(.f2)(.x)
else
.x)
}
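# Illustrative examples (not run):
# 5 |> ifelse_pipe(TRUE, ~ .x * 2, ~ .x + 1)  #> 10
# 5 |> ifelse_pipe(FALSE, ~ .x * 2)           #> 5 (no .f2, input returned unchanged)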
#' Get matrix from tibble
#'
#' @importFrom tidyr gather
#' @importFrom magrittr set_rownames
#'
#' @keywords internal
#' @noRd
#'
#' @param tbl A tibble
#' @param rownames A character string of the rownames
#'
#' @return A matrix
as_matrix <- function(tbl, rownames = NULL) {
# Define the variables as NULL to avoid CRAN NOTES
variable <- NULL
tbl %>%
ifelse_pipe(
tbl %>%
ifelse_pipe(!is.null(rownames), ~ .x %>% dplyr::select(-contains(rownames))) %>%
summarise_all(class) %>%
gather(variable, class) %>%
pull(class) %>%
unique() %>%
`%in%`(c("numeric", "integer")) %>% not() %>% any(),
~ {
warning("to_matrix says: there are NON-numerical columns, the matrix will NOT be numerical")
.x
}
) %>%
as.data.frame() %>%
# Deal with rownames column if present
ifelse_pipe(!is.null(rownames),
~ .x %>%
set_rownames(tbl %>% pull(!!rownames)) %>%
select(-!!rownames)) %>%
# Convert to matrix
as.matrix()
}
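# Illustrative example (not run; hypothetical data):
# tibble::tibble(gene = c("g1", "g2"), s1 = c(1, 2), s2 = c(3, 4)) |>
#   as_matrix(rownames = "gene")
# #> a 2 x 2 numeric matrix with rownames "g1" and "g2"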
#' vb_iterative
#'
#' @description Runs variational Bayes iteratively until it succeeds
#'
#'
#' @keywords internal
#' @noRd
#'
#' @param model A Stan model
#' @param output_samples An integer. How many samples to draw from the posterior
#' @param iter An integer. The maximum number of iterations
#' @param tol_rel_obj A real. The convergence tolerance on the relative objective
#' @param additional_parameters_to_save A character vector
#' @param data A data frame
#' @param output_dir A character. The directory where draws are written
#' @param seed An integer
#' @param init Initial values, or "random"
#' @param inference_method A character, either "variational" or "pathfinder"
#' @param cores An integer. The number of cores/threads to use
#' @param verbose A logical. Whether to print progress messages
#' @param psis_resample A logical. Whether to resample draws with PSIS (pathfinder only)
#' @param ... Further parameters passed to the Stan inference function
#'
#' @return A Stan fit object
#'
vb_iterative = function(model,
output_samples,
iter,
tol_rel_obj,
additional_parameters_to_save = c(),
data,
output_dir = output_dir,
seed,
init = "random",
inference_method,
cores = 1,
verbose = TRUE,
psis_resample = FALSE,
...) {
res = NULL
i = 0
while (is.null(res) & i < 5) {
res = tryCatch({
if(inference_method=="pathfinder")
my_res = model |>
sample_safe(
pathfinder_fx,
data = data,
tol_rel_obj = tol_rel_obj,
output_dir = output_dir,
seed = seed+i,
# init = init,
num_paths=50,
num_threads = cores,
single_path_draws = output_samples / 50 ,
max_lbfgs_iters=100,
history_size = 100,
show_messages = verbose,
psis_resample = psis_resample,
...
)
else if(inference_method=="variational")
my_res = model |>
sample_safe(
variational_fx,
data = data,
output_samples = output_samples,
iter = iter,
tol_rel_obj = tol_rel_obj,
output_dir = output_dir,
seed = seed+i,
init = init,
show_messages = verbose,
threads = cores,
...
)
boolFalse <- TRUE
return(my_res)
},
error = function(e) {
writeLines(sprintf("Further attempt with Variational Bayes: %s", e))
return(NULL)
},
finally = {
})
i = i + 1
}
if(is.null(res)) stop(sprintf("sccomp says: variational Bayes did not converge after %s attempts. Please use inference_method = \"hmc\" for HMC fitting.", i))
return(res)
}
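# Illustrative call (not run; `mod` and `stan_data` are hypothetical and stand
# for a compiled cmdstanr model and its data list). On failure the call is
# retried up to five times, each attempt incrementing the seed:
# vb_iterative(
#   mod, output_samples = 4000, iter = 10000, tol_rel_obj = 0.01,
#   data = stan_data, output_dir = tempdir(), seed = 42,
#   inference_method = "variational"
# )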
#' draws_to_tibble_x_y
#'
#' @importFrom tidyr pivot_longer
#' @importFrom rlang :=
#'
#' @param fit A fit object
#' @param par A character vector. The parameters to extract.
#' @param x A character. The first index.
#' @param y A character. The second index.
#'
#' @keywords internal
#' @noRd
draws_to_tibble_x_y = function(fit, par, x, y, number_of_draws = NULL) {
# Define the variables as NULL to avoid CRAN NOTES
dummy <- NULL
.variable <- NULL
.chain <- NULL
.iteration <- NULL
.draw <- NULL
.value <- NULL
par_names =
fit$metadata()$stan_variables %>% grep(sprintf("%s", par), ., value = TRUE)
fit$draws(variables = par, format = "draws_df") %>%
mutate(.iteration = seq_len(n())) %>%
pivot_longer(
names_to = "parameter", # c( ".chain", ".variable", x, y),
cols = contains(par),
#names_sep = "\\.?|\\[|,|\\]|:",
# names_ptypes = list(
# ".variable" = character()),
values_to = ".value"
) %>%
tidyr::extract(parameter, c(".chain", ".variable", x, y), "([1-9]+)?\\.?([a-zA-Z0-9_\\.]+)\\[([0-9]+),([0-9]+)") |>
# Warning message:
# Expected 5 pieces. Additional pieces discarded
suppressWarnings() %>%
mutate(
!!as.symbol(x) := as.integer(!!as.symbol(x)),
!!as.symbol(y) := as.integer(!!as.symbol(y))
) %>%
arrange(.variable, !!as.symbol(x), !!as.symbol(y), .chain) %>%
group_by(.variable, !!as.symbol(x), !!as.symbol(y)) %>%
mutate(.draw = seq_len(n())) %>%
ungroup() %>%
select(!!as.symbol(x), !!as.symbol(y), .chain, .iteration, .draw ,.variable , .value) %>%
filter(.variable == par)
}
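# Illustrative output shape (not run; assumes `fit` is a cmdstanr fit holding
# a parameter beta[C,M]):
# draws_to_tibble_x_y(fit, "beta", "C", "M")
# #> a long tibble with columns C, M, .chain, .iteration, .draw, .variable,
# #> .value, one row per posterior draw of each beta[C,M] element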
#' @importFrom tidyr separate
#' @importFrom purrr when
#'
#' @param fit A fit object from a statistical model, from the 'cmdstanr' package.
#' @param par A character vector specifying the parameters to extract from the fit object.
#' @param x A character string specifying the first index in the parameter names.
#' @param y A character string specifying the second index in the parameter names (optional).
#' @param probs A numerical vector specifying the quantiles to extract.
#'
#' @keywords internal
#' @noRd
summary_to_tibble = function(fit, par, x, y = NULL, probs = c(0.025, 0.25, 0.50, 0.75, 0.975)) {
# Extract parameter names from the fit object that match the 'par' argument
par_names = names(fit) %>% grep(sprintf("%s", par), ., value = TRUE)
# Avoid bug
#if(fit@stan_args[[1]]$method %>% is.null) fit@stan_args[[1]]$method = "hmc"
summary =
fit$summary(variables = par, "mean", ~quantile(.x, probs = probs, na.rm=TRUE)) %>%
rename(.variable = variable ) %>%
# When `y` is NULL it is simply dropped from `into`, so one call covers both cases
tidyr::separate(col = .variable, into = c(".variable", x, y), sep="\\[|,|\\]", convert = TRUE, extra="drop")
# summaries are returned only for HMC
if(!"n_eff" %in% colnames(summary)) summary = summary |> mutate(n_eff = NA)
if(!"R_k_hat" %in% colnames(summary)) summary = summary |> mutate(R_k_hat = NA)
summary
}
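# Illustrative output shape (not run; assumes `fit` is a cmdstanr fit holding
# a parameter beta[C,M]):
# summary_to_tibble(fit, "beta", "C", "M")
# #> a tibble with columns .variable, C, M, mean, the requested quantiles,
# #> and n_eff / R_k_hat (NA when inference was approximate)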
#' @importFrom rlang :=
label_deleterious_outliers = function(.my_data){
# Define the variables as NULL to avoid CRAN NOTES
.count <- NULL
`95%` <- NULL
`5%` <- NULL
X <- NULL
iteration <- NULL
outlier_above <- NULL
slope <- NULL
is_group_right <- NULL
outlier_below <- NULL
.my_data %>%
# join CI
mutate(outlier_above = !!.count > `95%`) %>%
mutate(outlier_below = !!.count < `5%`) %>%
# Mark if on the right of the factor scale
mutate(is_group_right = !!as.symbol(colnames(X)[2]) > mean( !!as.symbol(colnames(X)[2]) )) %>%
# Check if outlier might be deleterious for the statistics
mutate(
!!as.symbol(sprintf("deleterious_outlier_%s", iteration)) :=
(outlier_above & slope > 0 & is_group_right) |
(outlier_below & slope > 0 & !is_group_right) |
(outlier_above & slope < 0 & !is_group_right) |
(outlier_below & slope < 0 & is_group_right)
) %>%
select(-outlier_above, -outlier_below, -is_group_right)
}
#' @importFrom readr write_file
fit_model = function(
data_for_model, model_name, censoring_iteration = 1, cores = detectCores(), quantile = 0.95,
warmup_samples = 300, approximate_posterior_inference = NULL, inference_method, verbose = TRUE,
seed , pars = c("beta", "alpha", "prec_coeff","prec_sd"), output_samples = NULL, chains=NULL, max_sampling_iterations = 20000,
output_directory = "sccomp_draws_files",
...
)
{
# # if analysis approximated
# # If posterior analysis is approximated I just need enough
# how_many_posterior_draws_practical = ifelse(approximate_posterior_analysis, 1000, how_many_posterior_draws)
# additional_parameters_to_save = additional_parameters_to_save %>% c("lambda_log_param", "sigma_raw") %>% unique
# Find number of draws
draws_supporting_quantile = 50
if(is.null(output_samples)){
output_samples =
(draws_supporting_quantile/((1-quantile)/2)) %>% # /2 because I have two tails
max(4000)
if(output_samples > max_sampling_iterations) {
# message("sccomp says: the number of draws used to defined quantiles of the posterior distribution is capped to 20K.") # This means that for very low probability threshold the quantile could become unreliable. We suggest to limit the probability threshold between 0.1 and 0.01")
output_samples = max_sampling_iterations
}}
# Find optimal number of chains
if(is.null(chains))
chains =
find_optimal_number_of_chains(
how_many_posterior_draws = output_samples,
warmup = warmup_samples,
parallelisation_start_penalty = 100
) %>%
min(cores)
# chains = 3
init_list=list(
prec_coeff = c(5,0),
prec_sd = 1,
alpha = matrix(c(rep(5, data_for_model$M), rep(0, (data_for_model$A-1) *data_for_model$M)), nrow = data_for_model$A, byrow = TRUE),
beta_raw_raw = matrix(0, data_for_model$C , data_for_model$M-1) ,
mix_p = 0.1
)
if(data_for_model$n_random_eff>0){
init_list$random_effect_raw = matrix(0, data_for_model$ncol_X_random_eff[1] , data_for_model$M-1)
init_list$random_effect_sigma_raw = matrix(0, data_for_model$M-1 , data_for_model$how_many_factors_in_random_design[1])
init_list$sigma_correlation_factor = array(0, dim = c(
data_for_model$M-1,
data_for_model$how_many_factors_in_random_design[1],
data_for_model$how_many_factors_in_random_design[1]
))
init_list$random_effect_sigma_mu = 0.5 |> as.array()
init_list$random_effect_sigma_sigma = 0.2 |> as.array()
init_list$zero_random_effect = rep(0, times = 1) |> as.array()
}
if(data_for_model$n_random_eff>1){
init_list$random_effect_raw_2 = matrix(0, data_for_model$ncol_X_random_eff[2] , data_for_model$M-1)
init_list$random_effect_sigma_raw_2 = matrix(0, data_for_model$M-1 , data_for_model$how_many_factors_in_random_design[2])
init_list$sigma_correlation_factor_2 = array(0, dim = c(
data_for_model$M-1,
data_for_model$how_many_factors_in_random_design[2],
data_for_model$how_many_factors_in_random_design[2]
))
}
init = map(1:chains, ~ init_list) %>%
setNames(as.character(1:chains))
#output_directory = "sccomp_draws_files"
dir.create(output_directory, showWarnings = FALSE)
# Fit
mod = load_model(model_name, threads = cores)
# Avoid 0 proportions
if(data_for_model$is_proportion && min(data_for_model$y_proportion)==0){
warning("sccomp says: your proportion values include 0. Assuming that 0s derive from a precision threshold (e.g. deconvolution), 0s are converted to the smaller non 0 proportion value.")
data_for_model$y_proportion[data_for_model$y_proportion==0] =
min(data_for_model$y_proportion[data_for_model$y_proportion>0])
}
if(inference_method == "hmc"){
tryCatch({
mod |> sample_safe(
sample_fx,
data = data_for_model ,
chains = chains,
parallel_chains = chains,
threads_per_chain = ceiling(cores/chains),
iter_warmup = warmup_samples,
iter_sampling = as.integer(output_samples /chains),
#refresh = ifelse(verbose, 1000, 0),
seed = seed,
save_warmup = FALSE,
init = init,
output_dir = output_directory,
show_messages = verbose,
...
) |>
suppressWarnings()
},
error = function(e) {
      # It is unclear why this is needed, and why the model sometimes is not compiled correctly
      if(e |> as.character() |> str_detect("Model not compiled"))
        model = load_model(model_name, force=TRUE, threads = cores)
      else
        stop(e)
})
} else{
if(inference_method=="pathfinder") init = pf
else if(inference_method=="variational") init = list(init_list)
vb_iterative(
mod,
output_samples = output_samples ,
iter = 10000,
tol_rel_obj = 0.01,
data = data_for_model, refresh = ifelse(verbose, 1000, 0),
seed = seed,
output_dir = output_directory,
init = init,
inference_method = inference_method,
cores = cores,
psis_resample = FALSE,
verbose = verbose
) %>%
suppressWarnings()
}
}
#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom rlang :=
#'
#' @keywords internal
#' @noRd
parse_fit = function(data_for_model, fit, censoring_iteration = 1, chains){
# Define the variables as NULL to avoid CRAN NOTES
M <- NULL
fit %>%
draws_to_tibble_x_y("beta", "C", "M") %>%
left_join(tibble(C=seq_len(ncol(data_for_model$X)), C_name = colnames(data_for_model$X)), by = "C") %>%
nest(!!as.symbol(sprintf("beta_posterior_%s", censoring_iteration)) := -M)
}
#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom tidyr spread
#' @importFrom stats C
#' @importFrom rlang :=
#' @importFrom tibble enframe
#'
#' @keywords internal
#' @noRd
beta_to_CI = function(fitted, censoring_iteration = 1, false_positive_rate, factor_of_interest){
# Define the variables as NULL to avoid CRAN NOTES
M <- NULL
C_name <- NULL
.lower <- NULL
.median <- NULL
.upper <- NULL
effect_column_name = sprintf("composition_effect_%s", factor_of_interest) %>% as.symbol()
CI = fitted %>%
unnest(!!as.symbol(sprintf("beta_posterior_%s", censoring_iteration))) %>%
nest(data = -c(M, C, C_name)) %>%
# Attach beta
mutate(!!as.symbol(sprintf("beta_quantiles_%s", censoring_iteration)) := map(
data,
~ quantile(
.x$.value,
probs = c(false_positive_rate/2, 0.5, 1-(false_positive_rate/2))
) %>%
enframe() %>%
mutate(name = c(".lower", ".median", ".upper")) %>%
spread(name, value)
)) %>%
unnest(!!as.symbol(sprintf("beta_quantiles_%s", censoring_iteration))) %>%
select(-data, -C) %>%
pivot_wider(names_from = C_name, values_from=c(.lower , .median , .upper))
# Create main effect if exists
if(!is.na(factor_of_interest) )
CI |>
mutate(!!effect_column_name := !!as.symbol(sprintf(".median_%s", factor_of_interest))) %>%
nest(composition_CI = -c(M, !!effect_column_name))
else
CI |> nest(composition_CI = -c(M))
}
#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom stats C
#' @importFrom rlang :=
#'
#' @keywords internal
#' @noRd
alpha_to_CI = function(fitted, censoring_iteration = 1, false_positive_rate, factor_of_interest){
# Define the variables as NULL to avoid CRAN NOTES
M <- NULL
C_name <- NULL
.lower <- NULL
.median <- NULL
.upper <- NULL
effect_column_name = sprintf("variability_effect_%s", factor_of_interest) %>% as.symbol()
fitted %>%
unnest(!!as.symbol(sprintf("alpha_%s", censoring_iteration))) %>%
nest(data = -c(M, C, C_name)) %>%
# Attach beta
mutate(!!as.symbol(sprintf("alpha_quantiles_%s", censoring_iteration)) := map(
data,
~ quantile(
.x$.value,
probs = c(false_positive_rate/2, 0.5, 1-(false_positive_rate/2))
) %>%
enframe() %>%
mutate(name = c(".lower", ".median", ".upper")) %>%
spread(name, value)
)) %>%
unnest(!!as.symbol(sprintf("alpha_quantiles_%s", censoring_iteration))) %>%
select(-data, -C) %>%
pivot_wider(names_from = C_name, values_from=c(.lower , .median , .upper)) %>%
mutate(!!effect_column_name := !!as.symbol(sprintf(".median_%s", factor_of_interest))) %>%
nest(variability_CI = -c(M, !!effect_column_name))
}
#' Get Random Effect Design 2
#'
#' This function processes the random effect elements of the formula and creates design matrices
#' for random effect models.
#'
#' @param .data_ A data frame containing the data.
#' @param .sample A quosure representing the sample variable.
#' @param formula_composition The formula whose random effect terms are processed.
#'
#' @return A data frame with the processed design matrices for random effect models.
#'
#' @importFrom glue glue
#' @importFrom magrittr subtract
#' @importFrom purrr map2
#' @importFrom dplyr mutate
#' @importFrom dplyr select
#' @importFrom dplyr pull
#' @importFrom dplyr filter
#' @importFrom dplyr left_join
#' @importFrom dplyr mutate_all
#' @importFrom dplyr mutate_if
#' @importFrom dplyr as_tibble
#' @importFrom tidyr pivot_longer
#' @importFrom rlang enquo
#' @importFrom rlang quo_name
#' @importFrom tidyselect all_of
#' @importFrom readr type_convert
#' @noRd
get_random_effect_design2 = function(.data_, .sample, formula_composition ){
# Define the variables as NULL to avoid CRAN NOTES
formula <- NULL
.sample = enquo(.sample)
grouping_table =
formula_composition |>
formula_to_random_effect_formulae() |>
mutate(design = map2(
formula, grouping,
~ {
mydesign = .data_ |> get_design_matrix(.x, !!.sample)
mydesigncol_X_random_eff = .data_ |> select(all_of(.y)) |> pull(1) |> rep(ncol(mydesign)) |> matrix(ncol = ncol(mydesign))
mydesigncol_X_random_eff[mydesign==0L] = NA
colnames(mydesigncol_X_random_eff) = colnames(mydesign)
rownames(mydesigncol_X_random_eff) = rownames(mydesign)
mydesigncol_X_random_eff |>
as_tibble(rownames = quo_name(.sample)) |>
pivot_longer(-!!.sample, names_to = "factor", values_to = "grouping") |>
filter(!is.na(grouping)) |>
mutate("mean_idx" = glue("{factor}___{grouping}") |> as.factor() |> as.integer() )|>
with_groups(factor, ~ ..1 |> mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx))) |>
mutate(minus_sum = if_else(mean_idx==0, factor |> as.factor() |> as.integer(), 0L)) |>
# Make right rank
mutate(mean_idx = mean_idx |> as.factor() |> as.integer() |> subtract(1)) |>
# drop minus_sum if we just have one grouping per factor
with_groups(factor, ~ {
if(length(unique(..1$grouping)) == 1) ..1 |> mutate(., minus_sum = 0)
else ..1
}) |>
# Add value
left_join(
mydesign |>
as_tibble(rownames = quo_name(.sample)) |>
mutate_all(as.character) |>
readr::type_convert(guess_integer = TRUE ) |>
suppressMessages() |>
mutate_if(is.integer, ~1) |>
pivot_longer(-!!.sample, names_to = "factor"),
by = join_by(!!.sample, factor)
) |>
# Create unique name
mutate(group___label = glue("{factor}___{grouping}")) |>
mutate(group___numeric = group___label |> as.factor() |> as.integer()) |>
mutate(factor___numeric = `factor` |> as.factor() |> as.integer())
}))
}
#' Get Random Effect Design
#'
#' This function processes random effect elements in the data and creates design matrices
#' for random effect models.
#'
#' @param .data_ A data frame containing the data.
#' @param .sample A quosure representing the sample variable.
#' @param random_effect_elements A data frame containing the random effect elements.
#'
#' @return A data frame with the processed design matrices for random intercept models.
#'
#' @importFrom glue glue
#' @importFrom magrittr subtract
#' @importFrom dplyr select
#' @importFrom dplyr mutate
#' @importFrom dplyr pull
#' @importFrom dplyr if_else
#' @importFrom dplyr distinct
#' @importFrom dplyr group_by
#' @importFrom dplyr summarise
#' @importFrom rlang set_names
#' @importFrom purrr map_lgl
#' @importFrom purrr pmap
#' @importFrom purrr map_int
#' @importFrom dplyr with_groups
#' @importFrom rlang enquo
#' @importFrom rlang quo_name
#' @importFrom tidyselect all_of
#' @noRd
get_random_effect_design = function(.data_, .sample, random_effect_elements ){
# Define the variables as NULL to avoid CRAN NOTES
is_factor_continuous <- NULL
design <- NULL
max_mean_idx <- NULL
max_minus_sum <- NULL
max_factor_numeric <- NULL
max_group_numeric <- NULL
min_mean_idx <- NULL
min_minus_sum <- NULL
.sample = enquo(.sample)
# If there are no random effect elements, return an empty structure
if(nrow(random_effect_elements) == 0 )
return(
random_effect_elements |>
mutate(
design = list(),
is_factor_continuous = logical()
)
)
# Otherwise process
random_effect_elements |>
mutate(is_factor_continuous = map_lgl(
`factor`,
~ .x != "(Intercept)" && .data_ |> select(all_of(.x)) |> pull(1) |> is("numeric")
)) |>
mutate(design = pmap(
list(grouping, `factor`, is_factor_continuous),
~ {
# Make exception for random intercept
if(..2 == "(Intercept)")
.data_ = .data_ |> mutate(`(Intercept)` = 1)
.data_ =
.data_ |>
select(!!.sample, ..1, ..2) |>
set_names(c(quo_name(.sample), "group___", "factor___")) |>
mutate(group___numeric = group___, factor___numeric = factor___) |>
mutate(group___label := glue("{group___}___{.y}")) |>
mutate(factor___ = ..2)
# If factor is continuous
if(..3)
.data_ %>%
# Mutate random intercept grouping to number
mutate(group___numeric = factor(group___numeric) |> as.integer()) |>
# If intercept is not defined create it
mutate(., factor___numeric = 1L) |>
# If categorical make sure the group is independent for factors
mutate(mean_idx = glue("{group___numeric}") |> as.factor() |> as.integer()) |>
mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx)) |>
mutate(mean_idx = as.factor(mean_idx) |> as.integer() |> subtract(1L)) |>
mutate(minus_sum = if_else(mean_idx==0, 1L, 0L))
#|>
# distinct()
# If factor is discrete
else
.data_ %>%
# Mutate random intercept grouping to number
mutate(group___numeric = factor(group___numeric) |> as.integer()) |>
# If categorical make sure the group is independent for factors
mutate(mean_idx = glue("{factor___numeric}{group___numeric}") |> as.factor() |> as.integer()) |>
with_groups(factor___numeric, ~ ..1 |> mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx))) |>
mutate(mean_idx = as.factor(mean_idx) |> as.integer() |> subtract(1L)) |>
mutate(minus_sum = if_else(mean_idx==0, as.factor(factor___numeric) |> as.integer(), 0L)) |>
# drop minus_sum if we just have one group___numeric per factor
with_groups(factor___numeric, ~ {
if(length(unique(..1$group___numeric)) == 1) ..1 |> mutate(., minus_sum = 0)
else ..1
}) |>
mutate(factor___numeric = as.factor(factor___numeric) |> as.integer())
#|>
# distinct()
}
)) |>
# Make indexes unique across parameters
mutate(
max_mean_idx = map_int(design, ~ ..1 |> pull(mean_idx) |> max()),
max_minus_sum = map_int(design, ~ ..1 |> pull(minus_sum) |> max()),
max_factor_numeric = map_int(design, ~ ..1 |> pull(factor___numeric) |> max()),
max_group_numeric = map_int(design, ~ ..1 |> pull(group___numeric) |> max())
) |>
mutate(
min_mean_idx = cumsum(max_mean_idx) - max_mean_idx ,
min_minus_sum = cumsum(max_minus_sum) - max_minus_sum,
max_factor_numeric = cumsum(max_factor_numeric) - max_factor_numeric,
max_group_numeric = cumsum(max_group_numeric) - max_group_numeric
) |>
mutate(design = pmap(
list(design, min_mean_idx, min_minus_sum, max_factor_numeric, max_group_numeric),
~ ..1 |>
mutate(
mean_idx = if_else(mean_idx>0, mean_idx + ..2, mean_idx),
minus_sum = if_else(minus_sum>0, minus_sum + ..3, minus_sum),
factor___numeric = factor___numeric + ..4,
group___numeric = group___numeric + ..5
)
))
}
#' @importFrom glue glue
#' @noRd
get_design_matrix = function(.data_spread, formula, .sample){
.sample = enquo(.sample)
design_matrix =
.data_spread %>%
select(!!.sample, parse_formula(formula)) |>
mutate(across(where(is.numeric), scale)) |>
model.matrix(formula, data=_)
rownames(design_matrix) = .data_spread |> pull(!!.sample)
design_matrix
}
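# Illustrative example (not run; hypothetical data):
# tibble::tibble(sample = c("s1", "s2"), type = c("a", "b")) |>
#   get_design_matrix(~ type, sample)
# #> a 2 x 2 model matrix with columns "(Intercept)" and "typeb",
# #> rownames "s1" and "s2"; numeric covariates would be scaled first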
#' Check Random Effect Design
#'
#' This function checks the validity of the random effect design in the data.
#'
#' @param .data A data frame containing the data.
#' @param factor_names A character vector of factor names.
#' @param random_effect_elements A data frame containing the random effect elements.
#' @param formula The formula used for the model.
#' @param X The design matrix.
#'
#' @return A data frame with the checked random effect elements.
#'
#' @importFrom tidyr nest
#' @importFrom dplyr mutate
#' @importFrom dplyr select
#' @importFrom dplyr pull
#' @importFrom dplyr filter
#' @importFrom dplyr distinct
#' @importFrom rlang set_names
#' @importFrom tidyr unite
#' @importFrom purrr map2
#' @importFrom stringr str_subset
#' @importFrom readr type_convert
#' @noRd
check_random_effect_design = function(.data, factor_names, random_effect_elements, formula, X){
# Define the variables as NULL to avoid CRAN NOTES
factors <- NULL
groupings <- NULL
.data_ = .data
# Loop across groupings
random_effect_elements |>
nest(factors = `factor` ) |>
mutate(checked = map2(
grouping, factors,
~ {
.y = unlist(.y)
# Check that the group column is categorical
stopifnot("sccomp says: the grouping column should be categorical (not numeric)" =
.data_ |>
select(all_of(.x)) |>
pull(1) |>
class() %in%
c("factor", "logical", "character")
)
# # Check sanity of the grouping if only random intercept
# stopifnot(
# "sccomp says: the random intercept completely confounded with one or more discrete factors" =
# !(
# !.y |> equals("(Intercept)") &&
# .data_ |> select(any_of(.y)) |> suppressWarnings() |> pull(1) |> class() %in% c("factor", "character") |> any() &&
# .data_ |>
# select(.x, any_of(.y)) |>
# select_if(\(x) is.character(x) | is.factor(x) | is.logical(x)) |>
# distinct() %>%
#
# # TEMPORARY FIX
# set_names(c(colnames(.)[1], 'factor___temp')) |>
#
# count(factor___temp) |>
# pull(n) |>
# equals(1) |>
# any()
# )
# )
# # Check if random intercept with random continuous slope. At the moment is not possible
# # Because it would require I believe a multivariate prior
# stopifnot(
# "sccomp says: continuous random slope is not supported yet" =
# !(
# .y |> str_subset("1", negate = TRUE) |> length() |> gt(0) &&
# .data_ |>
# select(
# .y |> str_subset("1", negate = TRUE)
# ) |>
# map_chr(class) %in%
# c("integer", "numeric")
# )
# )
# Check if random intercept with random continuous slope. At the moment is not possible
# Because it would require I believe a multivariate prior
stopifnot(
"sccomp says: currently, discrete random slope is only supported in a intercept-free model. For example ~ 0 + treatment + (treatment | group)" =
!(
# If I have both random intercept and random discrete slope
.y |> equals("(Intercept)") |> any() &&
length(.y) > 1 &&
# If I have random slope and non-intercept-free model
.data_ |> select(any_of(.y)) |> suppressWarnings() |> pull(1) |> class() %in% c("factor", "character") |> any()
)
)
# TODO: revisit this check
# stopifnot(
# "sccomp says: the groups in the formula (factor | group) should not be shared across factor groups" =
# !(
# # If I duplicated groups
# .y |> identical("(Intercept)") |> not() &&
# .data_ |> select(.y |> setdiff("(Intercept)")) |> lapply(class) != "numeric" &&
# .data_ |>
# select(.x, .y |> setdiff("(Intercept)")) |>
#
# # Drop the factor represented by the intercept if any
# mutate(`parameter` = .y |> setdiff("(Intercept)")) |>
# unite("factor_name", c(parameter, factor), sep = "", remove = FALSE) |>
# filter(factor_name %in% colnames(X)) |>
#
# # Count
# distinct() %>%
# set_names(as.character(1:ncol(.))) |>
# count(`1`) |>
# filter(n>1) |>
# nrow() |>
# gt(1)
#
# )
# )
}
))
random_effect_elements |>
nest(groupings = grouping ) |>
mutate(checked = map2(`factor`, groupings, ~{
# Check the same group spans multiple factors
stopifnot(
"sccomp says: the groups in the formula (factor | group) should be present in only one factor, including the intercept" =
!(
# If I duplicated groups
.y |> unlist() |> length() |> gt(1)
)
)
}))
}
#' @importFrom purrr when
#' @importFrom stats model.matrix
#' @importFrom tidyr expand_grid
#' @importFrom stringr str_detect
#' @importFrom stringr str_remove_all
#' @importFrom purrr reduce
#' @importFrom stats as.formula
#'
#' @keywords internal
#' @noRd
#'
data_spread_to_model_input =
function(
.data_spread, formula, .sample, .cell_type, .count,
truncation_ajustment = 1, approximate_posterior_inference ,
formula_variability = ~ 1,
contrasts = NULL,
bimodal_mean_variability_association = FALSE,
use_data = TRUE,
random_effect_elements){
# Define the variables as NULL to avoid CRAN NOTES
exposure <- NULL
design <- NULL
mat <- NULL
factor___numeric <- NULL
mean_idx <- NULL
design_matrix <- NULL
minus_sum <- NULL
group___numeric <- NULL
idx <- NULL
group___label <- NULL
parameter <- NULL
group <- NULL
design_matrix_col <- NULL
# Prepare column same enquo
.sample = enquo(.sample)
.cell_type = enquo(.cell_type)
.count = enquo(.count)
.grouping_for_random_effect =
random_effect_elements |>
pull(grouping) |>
unique()
if (length(.grouping_for_random_effect)==0 ) .grouping_for_random_effect = "random_effect"
X =
.data_spread |>
get_design_matrix(
# Drop random intercept
formula |>
as.character() |>
str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
paste(collapse="") |>
as.formula(),
!!.sample
)
Xa =
.data_spread |>
get_design_matrix(
# Drop random intercept
formula_variability |>
as.character() |>
str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
paste(collapse="") |>
as.formula() ,
!!.sample
)
XA = Xa %>%
as_tibble() %>%
distinct()
A = ncol(XA);
Ar = nrow(XA);
factor_names = parse_formula(formula)
factor_names_variability = parse_formula(formula_variability)
cell_cluster_names = .data_spread %>% select(-!!.sample, -any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% colnames()
# Random intercept
if(nrow(random_effect_elements)>0 ) {
#check_random_effect_design(.data_spread, any_of(factor_names), random_effect_elements, formula, X)
random_effect_grouping = get_random_effect_design2(.data_spread, !!.sample, formula )
# Actual parameters, excluding for the sum to one parameters
is_random_effect = 1
random_effect_grouping =
random_effect_grouping |>
mutate(design_matrix = map(
design,
~ ..1 |>
select(!!.sample, group___label, value) |>
pivot_wider(names_from = group___label, values_from = value) |>
mutate(across(everything(), ~ .x |> replace_na(0)))
))
X_random_effect =
random_effect_grouping |>
pull(design_matrix) |>
_[[1]] |>
as_matrix(rownames = quo_name(.sample))
  # For now that Stan does not have tuples, I just allow a maximum of two levels
  if(random_effect_grouping |> nrow() > 2) stop("sccomp says: at the moment sccomp allows a maximum of two groupings")
# This will be modularised with the new stan
if(random_effect_grouping |> nrow() > 1)
X_random_effect_2 =
random_effect_grouping |>
pull(design_matrix) |>
_[[2]] |>
as_matrix(rownames = quo_name(.sample))
else X_random_effect_2 = X_random_effect[,0,drop=FALSE]
n_random_eff = random_effect_grouping |> nrow()
ncol_X_random_eff =
random_effect_grouping |>
mutate(n = map_int(design, ~.x |> distinct(group___numeric) |> nrow())) |>
pull(n)
if(ncol_X_random_eff |> length() < 2) ncol_X_random_eff[2] = 0
# TEMPORARY
group_factor_indexes_for_covariance =
X_random_effect |>
colnames() |>
enframe(value = "parameter", name = "order") |>
separate(parameter, c("factor", "group"), "___", remove = FALSE) |>
complete(factor, group, fill = list(order=0)) |>
select(-parameter) |>
pivot_wider(names_from = group, values_from = order) |>
as_matrix(rownames = "factor")
n_groups = group_factor_indexes_for_covariance |> ncol()
# This will be modularised with the new stan
if(random_effect_grouping |> nrow() > 1)
group_factor_indexes_for_covariance_2 =
X_random_effect_2 |>
colnames() |>
enframe(value = "parameter", name = "order") |>
separate(parameter, c("factor", "group"), "___", remove = FALSE) |>
complete(factor, group, fill = list(order=0)) |>
select(-parameter) |>
pivot_wider(names_from = group, values_from = order) |>
as_matrix(rownames = "factor")
else group_factor_indexes_for_covariance_2 = matrix()[0,0, drop=FALSE]
n_groups = n_groups |> c(group_factor_indexes_for_covariance_2 |> ncol())
how_many_factors_in_random_design = list(group_factor_indexes_for_covariance, group_factor_indexes_for_covariance_2) |> map_int(nrow)
} else {
X_random_effect = matrix(rep(1, nrow(.data_spread)))[,0, drop=FALSE]
X_random_effect_2 = matrix(rep(1, nrow(.data_spread)))[,0, drop=FALSE] # This will be modularised with the new stan
is_random_effect = 0
ncol_X_random_eff = c(0,0)
n_random_eff = 0
n_groups = c(0,0)
how_many_factors_in_random_design = c(0,0)
group_factor_indexes_for_covariance = matrix()[0,0, drop=FALSE]
group_factor_indexes_for_covariance_2 = matrix()[0,0, drop=FALSE] # This will be modularised with the new stan
}
y = .data_spread %>% select(-any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% as_matrix(rownames = quo_name(.sample))
# Detect whether the input is proportions (all values within [0, 1])
is_proportion = y |> as.numeric() |> max() |> between(0,1) |> all()
if(is_proportion){
y_proportion = y
y = y[0,,drop = FALSE]
}
else{
y = y
y_proportion = y[0,,drop = FALSE]
}
data_for_model =
list(
N = .data_spread %>% nrow(),
M = .data_spread %>% select(-!!.sample, -any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% ncol(),
exposure = .data_spread$exposure,
is_proportion = is_proportion,
y = y,
y_proportion = y_proportion,
X = X,
XA = XA,
Xa = Xa,
C = ncol(X),
A = A,
Ar = Ar,
truncation_ajustment = truncation_ajustment,
is_vb = as.integer(approximate_posterior_inference),
bimodal_mean_variability_association = bimodal_mean_variability_association,
use_data = use_data,
# Random intercept
is_random_effect = is_random_effect,
ncol_X_random_eff = ncol_X_random_eff,
n_random_eff = n_random_eff,
n_groups = n_groups,
X_random_effect = X_random_effect,
X_random_effect_2 = X_random_effect_2,
group_factor_indexes_for_covariance = group_factor_indexes_for_covariance,
group_factor_indexes_for_covariance_2 = group_factor_indexes_for_covariance_2,
how_many_factors_in_random_design = how_many_factors_in_random_design,
# For parallel chains
grainsize = 1,
## LOO
enable_loo = FALSE
)
# Add censoring
data_for_model$is_truncated = 0
data_for_model$truncation_up = matrix(rep(-1, data_for_model$M * data_for_model$N), ncol = data_for_model$M)
data_for_model$truncation_down = matrix(rep(-1, data_for_model$M * data_for_model$N), ncol = data_for_model$M)
data_for_model$truncation_not_idx = seq_len(data_for_model$M*data_for_model$N)
data_for_model$TNS = length(data_for_model$truncation_not_idx)
data_for_model$truncation_not_idx_minimal = matrix(c(1,1), nrow = 1)[0,,drop=FALSE]
data_for_model$TNIM = 0
# Add parameter factor dictionary
data_for_model$factor_parameter_dictionary = tibble()
if(.data_spread |> select(any_of(parse_formula(formula))) |> lapply(class) %in% c("factor", "character") |> any())
data_for_model$factor_parameter_dictionary =
data_for_model$factor_parameter_dictionary |> bind_rows(
# For discrete
.data_spread |>
select(any_of(parse_formula(formula))) |>
distinct() |>
# Drop numerical
select_if(function(x) !is.numeric(x)) |>
pivot_longer(everything(), names_to = "factor", values_to = "parameter") %>%
unite("design_matrix_col", c(`factor`, parameter), sep="", remove = FALSE) |>
select(-parameter) |>
filter(design_matrix_col %in% colnames(data_for_model$X)) %>%
distinct()
)
# For continuous
if(.data_spread |> select(all_of(parse_formula(formula))) |> lapply(class) |> equals("numeric") |> any())
data_for_model$factor_parameter_dictionary =
data_for_model$factor_parameter_dictionary |>
bind_rows(
tibble(
design_matrix_col = .data_spread |>
select(all_of(parse_formula(formula))) |>
distinct() |>
# Drop numerical
select_if(function(x) is.numeric(x)) |>
names()
) |>
mutate(`factor` = design_matrix_col)
)
# If constrasts is set it is a bit more complicated
if(! is.null(contrasts))
data_for_model$factor_parameter_dictionary =
data_for_model$factor_parameter_dictionary |>
distinct() |>
expand_grid(parameter=contrasts) |>
filter(str_detect(parameter, design_matrix_col )) |>
select(-design_matrix_col) |>
rename(design_matrix_col = parameter) |>
distinct()
data_for_model$intercept_in_design = X[,1] |> unique() |> identical(1)
if (data_for_model$intercept_in_design | length(factor_names_variability) == 0) {
data_for_model$A_intercept_columns = 1
} else {
data_for_model$A_intercept_columns =
.data_spread |>
select(any_of(factor_names[1])) |>
distinct() |>
nrow()
}
if (data_for_model$intercept_in_design ) {
data_for_model$B_intercept_columns = 1
} else {
data_for_model$B_intercept_columns =
.data_spread |>
select(any_of(factor_names[1])) |>
distinct() |>
nrow()
}
# Return
data_for_model
}
data_to_spread = function(.data, formula, .sample, .cell_type, .count, .grouping_for_random_effect){
.sample = enquo(.sample)
.cell_type = enquo(.cell_type)
.count = enquo(.count)
.grouping_for_random_effect = .grouping_for_random_effect |> map(~ .x |> quo_name() ) |> unlist()
is_proportion = .data |> pull(!!.count) |> max() <= 1
.data =
.data |>
nest(data = -!!.sample)
# If proportions, set exposure = 1
if(is_proportion) .data = .data |> mutate(exposure = 1)
else
.data =
.data |>
mutate(exposure = map_int(data, ~ .x |> pull(!!.count) |> sum() ))
.data |>
unnest(data) |>
select(!!.sample, !!.cell_type, exposure, !!.count, parse_formula(formula), any_of(.grouping_for_random_effect)) |>
spread(!!.cell_type, !!.count)
}
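# Illustrative example (not run; `counts_long` is a hypothetical long-format
# count table with one row per sample/cell-group pair):
# counts_long |> data_to_spread(~ type, sample, cell_group, count, list())
# #> one row per sample with columns: sample, exposure, the factor columns,
# #> and one count column per cell group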
#' @importFrom purrr when
#' @importFrom stats model.matrix
#'
#' @keywords internal
#' @noRd
#'
data_simulation_to_model_input =
function(.data, formula, .sample, .cell_type, .exposure, .coefficients, truncation_ajustment = 1, approximate_posterior_inference ){
# Define the variables as NULL to avoid CRAN NOTES
sd <- NULL
. <- NULL
# Prepare column same enquo
.sample = enquo(.sample)
.cell_type = enquo(.cell_type)
.exposure = enquo(.exposure)
.coefficients = enquo(.coefficients)
factor_names = parse_formula(formula)
sample_data =
.data %>%
select(!!.sample, any_of(factor_names)) %>%
distinct() %>%
arrange(!!.sample)
X =
sample_data %>%
model.matrix(formula, data=.) %>%
apply(2, function(x) {
if(sd(x)==0 ) x
else x |> scale(scale=FALSE)
} ) %>%
{
.x = (.)
rownames(.x) = sample_data %>% pull(!!.sample)
.x
}
if(factor_names == "1") XA = X[,1, drop=FALSE]
else XA = X[,c(1,2), drop=FALSE]
XA = XA |>
as_tibble() |>
distinct()
cell_cluster_names =
.data %>%
distinct(!!.cell_type) %>%
arrange(!!.cell_type) %>%
pull(!!.cell_type)
coefficients =
.data %>%
select(!!.cell_type, !!.coefficients) %>%
unnest(!!.coefficients) %>%
distinct() %>%
arrange(!!.cell_type) %>%
as_matrix(rownames = quo_name(.cell_type)) %>%
t()
list(
N = .data %>% distinct(!!.sample) %>% nrow(),
M = .data %>% distinct(!!.cell_type) %>% nrow(),
exposure = .data %>% distinct(!!.sample, !!.exposure) %>% arrange(!!.sample) %>% pull(!!.exposure),
X = X,
XA = XA,
C = ncol(X),
A = ncol(XA),
beta = coefficients
)
}
#' Choose the number of chains based on how many draws we need from the posterior distribution.
#' Because there is a fixed cost (warmup) to starting a new chain,
#' we use the smallest number of chains that still parallelises the draws efficiently.
#' @param how_many_posterior_draws The number of posterior draws needed
#' @param max_number_to_check An upper bound on the number of chains to consider
#'
#' @keywords internal
#' @noRd
#'
#' @return An integer. The optimal number of chains
find_optimal_number_of_chains = function(how_many_posterior_draws = 100,
max_number_to_check = 100, warmup = 200, parallelisation_start_penalty = 100) {
# Define the variables as NULL to avoid CRAN NOTES
chains <- NULL
chains_df =
tibble(chains = seq_len(max_number_to_check)) %>%
mutate(tot = (how_many_posterior_draws / chains) + warmup + (parallelisation_start_penalty * chains))
d1 <- diff(chains_df$tot) / diff(seq_len(nrow(chains_df))) # first derivative
abs(d1) %>% order() %>% .[1] # Find derivative == 0
}
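# Worked example of the cost model: each of k chains pays the warmup plus a
# start-up penalty and shares the draws, so tot(k) = draws/k + warmup + penalty*k.
# With 4000 draws and penalty 100, the minimum is near sqrt(4000/100) ~ 6 chains,
# where the first derivative of the cost changes sign:
# find_optimal_number_of_chains(4000, warmup = 300, parallelisation_start_penalty = 100)
# #> [1] 6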
get.elbow.points.indices <- function(x, y, threshold) {
# From https://stackoverflow.com/questions/41518870/finding-the-elbow-knee-in-a-curve
d1 <- diff(y) / diff(x) # first derivative
d2 <- diff(d1) / diff(x[-1]) # second derivative
indices <- which(abs(d2) > threshold)
return(indices)
}
#' @importFrom magrittr divide_by
#' @importFrom magrittr multiply_by
#' @importFrom stats C
#'
#' @keywords internal
#' @noRd
#'
get_probability_non_zero_OLD = function(.data, prefix = "", test_above_logit_fold_change = 0){
# Define the variables as NULL to avoid CRAN NOTES
.draw <- NULL
M <- NULL
C_name <- NULL
bigger_zero <- NULL
smaller_zero <- NULL
probability_column_name = sprintf("%s_prob_H0", prefix) %>% as.symbol()
total_draws = .data %>% pull(2) %>% .[[1]] %>% distinct(.draw) %>% nrow()
.data %>%
unnest(2 ) %>%
filter(C ==2) %>%
nest(data = -c(M, C_name)) %>%
mutate(
bigger_zero = map_int(data, ~ .x %>% filter(.value>test_above_logit_fold_change) %>% nrow),
smaller_zero = map_int(data, ~ .x %>% filter(.value< -test_above_logit_fold_change) %>% nrow)
) %>%
rowwise() %>%
mutate(
!!probability_column_name :=
1 - (
max(bigger_zero, smaller_zero) %>%
#max(1) %>%
divide_by(total_draws)
)
) %>%
ungroup() %>%
select(M, !!probability_column_name)
# %>%
# mutate(false_discovery_rate = cummean(prob_non_zero))
}
#' @importFrom magrittr divide_by
#' @importFrom magrittr multiply_by
#' @importFrom stats C
#' @importFrom stats setNames
#'
#' @keywords internal
#' @noRd
#'
get_probability_non_zero_ = function(fit, parameter, prefix = "", test_above_logit_fold_change = 0){
# Define the variables as NULL to avoid CRAN NOTES
M <- NULL
C_name <- NULL
bigger_zero <- NULL
smaller_zero <- NULL
draws = fit$draws(
variables = parameter,
inc_warmup = FALSE,
format = getOption("cmdstanr_draws_format", "draws_matrix")
)
total_draws = dim(draws)[1]
bigger_zero =
draws %>%
apply(2, function(x) (x>test_above_logit_fold_change) %>% which %>% length)
smaller_zero =
draws %>%
apply(2, function(x) (x< -test_above_logit_fold_change) %>% which %>% length)
(1 - (pmax(bigger_zero, smaller_zero) / total_draws)) %>%
enframe() %>%
tidyr::extract(name, c("C", "M"), ".+\\[([0-9]+),([0-9]+)\\]") %>%
mutate(across(c(C, M), ~ as.integer(.x))) %>%
tidyr::spread(C, value)
}
get_probability_non_zero = function(draws, test_above_logit_fold_change = 0, probability_column_name){
draws %>%
with_groups(c(M, C_name), ~ .x |> summarise(
bigger_zero = which(.value>test_above_logit_fold_change) |> length(),
smaller_zero = which(.value< -test_above_logit_fold_change) |> length(),
n=n()
)) |>
mutate(!!as.symbol(probability_column_name) := (1 - (pmax(bigger_zero, smaller_zero) / n)))
}
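# Worked example: for a cell group where 95 of 100 draws of the effect exceed
# `test_above_logit_fold_change` and 5 fall below its negative, the posterior
# probability of the null hypothesis is 1 - max(95, 5) / 100 = 0.05.
# `draws` is expected to carry columns M, C_name and .value (as grouped above).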
#' @keywords internal
#' @noRd
#'
parse_generated_quantities = function(rng, number_of_draws = 1){
# Define the variables as NULL to avoid CRAN NOTES
.draw <- NULL
N <- NULL
.value <- NULL
generated_counts <- NULL
M <- NULL
generated_proportions <- NULL
draws_to_tibble_x_y(rng, "counts", "N", "M", number_of_draws) %>%
with_groups(c(.draw, N), ~ .x %>% mutate(generated_proportions = .value/max(1, sum(.value)))) %>%
filter(.draw<= number_of_draws) %>%
rename(generated_counts = .value, replicate = .draw) %>%
mutate(generated_counts = as.integer(generated_counts)) %>%
select(M, N, generated_proportions, generated_counts, replicate)
}
#' design_matrix_and_coefficients_to_simulation
#'
#' @description Create simulation from design matrix and coefficient matrix
#'
#' @importFrom dplyr left_join
#' @importFrom tidyr expand_grid
#'
#' @keywords internal
#' @noRd
#'
#' @param design_matrix A matrix
#' @param coefficient_matrix A matrix
#'
#' @return A data frame
#'
#'
#'
design_matrix_and_coefficients_to_simulation = function(
design_matrix, coefficient_matrix, .estimate_object
){
# Define the variables as NULL to avoid CRAN NOTES
cell_type <- NULL
beta_1 <- NULL
beta_2 <- NULL
design_df = as.data.frame(design_matrix)
coefficient_df = as.data.frame(coefficient_matrix)
rownames(design_df) = sprintf("sample_%s", seq_len(nrow(design_df)))
colnames(design_df) = sprintf("factor_%s", seq_len(ncol(design_df)))
rownames(coefficient_df) = sprintf("cell_type_%s", seq_len(nrow(coefficient_df)))
colnames(coefficient_df) = sprintf("beta_%s", seq_len(ncol(coefficient_df)))
input_data =
expand_grid(
sample = rownames(design_df),
cell_type = rownames(coefficient_df)
) |>
left_join(design_df |> as_tibble(rownames = "sample") , by = "sample") |>
left_join(coefficient_df |>as_tibble(rownames = "cell_type"), by = "cell_type")
simulate_data(.data = input_data,
.estimate_object = .estimate_object,
formula_composition = ~ factor_1 ,
.sample = sample,
.cell_group = cell_type,
.coefficients = c(beta_1, beta_2),
mcmc_seed = sample(1e5, 1)
)
}
#' @importFrom rlang ensym
#' @noRd
class_list_to_counts = function(.data, .sample, .cell_group){
.sample_for_tidyr = ensym(.sample)
.cell_group_for_tidyr = ensym(.cell_group)
.sample = enquo(.sample)
.cell_group = enquo(.cell_group)
.data %>%
count(!!.sample,
!!.cell_group,
name = "count") %>%
complete(
!!.sample_for_tidyr,!!.cell_group_for_tidyr,
fill = list(count = 0)
) %>%
mutate(count = as.integer(count))
}
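# Illustrative example (not run; hypothetical cell metadata, one row per cell):
# tibble::tibble(sample = c("s1", "s1", "s2"), cell_group = c("a", "b", "a")) |>
#   class_list_to_counts(sample, cell_group)
# #> 4 rows; the missing pair (s2, b) is completed with count = 0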
#' @importFrom dplyr cummean
#' @noRd
get_FDR = function(x){
# Define the variables as NULL to avoid CRAN NOTES
value <- NULL
name <- NULL
FDR <- NULL
enframe(x) %>%
arrange(value) %>%
mutate(FDR = cummean(value)) %>%
arrange(name) %>%
pull(FDR)
}
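# Worked example: posterior probabilities of the null c(0.01, 0.2, 0.05) are
# sorted (0.01, 0.05, 0.2), cumulatively averaged (0.01, 0.03, 0.0867), and
# returned in the original order:
# get_FDR(c(0.01, 0.2, 0.05))
# #> [1] 0.0100000 0.0866667 0.0300000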
#' Plot 1D Intervals for Cell-group Effects
#'
#' This function creates a series of 1D interval plots for cell-group effects, highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param test_composition_above_logit_fold_change A positive real number. It is the effect threshold used for the hypothesis test. A value of 0.2 corresponds to a change in cell proportion of 10% for a cell type with a baseline proportion of 50%, that is, the cell type goes from 45% to 50%. When the baseline proportion is closer to 0 or 1 this effect threshold has consistent value on the logit unconstrained scale.
#' @importFrom patchwork wrap_plots
#' @importFrom forcats fct_reorder
#' @importFrom tidyr drop_na
#'
#' @export
#'
#' @return A combined plot of 1D interval plots.
#' @examples
#' # Example usage:
#' # plot_1D_intervals(.data, significance_threshold = 0.05)
plot_1D_intervals = function(.data, significance_threshold = 0.05, test_composition_above_logit_fold_change = .data |> attr("test_composition_above_logit_fold_change")){
# Define the variables as NULL to avoid CRAN NOTES
parameter <- NULL
estimate <- NULL
value <- NULL
.cell_group = attr(.data, ".cell_group")
# Check whether tests have been run
if(.data |> select(ends_with("FDR")) |> ncol() |> equals(0))
stop("sccomp says: to produce plots, you need to run the function sccomp_test() on your estimates.")
plot_list =
.data |>
filter(parameter != "(Intercept)") |>
# Reshape data
select(-contains("n_eff"), -contains("R_k_hat")) |>
pivot_longer(c(contains("c_"), contains("v_")), names_sep = "_", names_to = c("which", "estimate")) |>
pivot_wider(names_from = estimate, values_from = value) |>
# Nest data by parameter and which
nest(data = -c(parameter, which)) |>
mutate(plot = pmap(
list(data, which, parameter),
~ {
# Check if there are any statistics to plot
if(..1 |> filter(!effect |> is.na()) |> nrow() |> equals(0))
return(NA)
# Create ggplot for each nested data
ggplot(..1, aes(x = effect, y = fct_reorder(!!.cell_group, effect))) +
geom_vline(xintercept = test_composition_above_logit_fold_change, colour = "grey") +
geom_vline(xintercept = -test_composition_above_logit_fold_change, colour = "grey") +
geom_errorbar(aes(xmin = lower, xmax = upper, color = FDR < significance_threshold)) +
geom_point() +
scale_color_brewer(palette = "Set1") +
xlab("Credible interval of the slope") +
ylab("Cell group") +
ggtitle(sprintf("%s %s", ..2, ..3)) +
multipanel_theme +
theme(legend.position = "bottom")
}
)) %>%
# Filter out NA plots
filter(!plot |> is.na()) |>
pull(plot)
# Combine all individual plots into one plot
plot_list |>
wrap_plots(ncol = plot_list |> length() |> sqrt() |> ceiling())
}
#' Plot 2D Intervals for Mean-Variance Association
#'
#' This function creates a 2D interval plot for mean-variance association, highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param test_composition_above_logit_fold_change A positive real number. It is the effect threshold used for the hypothesis test. A value of 0.2 corresponds to a change in cell proportion of 10% for a cell type with a baseline proportion of 50%, that is, the cell type goes from 45% to 50%. When the baseline proportion is closer to 0 or 1 this effect threshold has consistent value on the logit unconstrained scale.
#'
#'
#' @importFrom dplyr filter arrange mutate if_else row_number
#' @importFrom ggplot2 ggplot geom_vline geom_hline geom_errorbar geom_point annotate aes facet_wrap
#' @importFrom ggrepel geom_text_repel
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#' @importFrom magrittr equals
#'
#' @export
#'
#' @return A ggplot object representing the 2D interval plot.
#' @examples
#' # Example usage:
#' # plot_2D_intervals(.data, significance_threshold = 0.05)
plot_2D_intervals = function(.data, significance_threshold = 0.05, test_composition_above_logit_fold_change = .data |> attr("test_composition_above_logit_fold_change")){
# Define the variables as NULL to avoid CRAN NOTES
v_effect <- NULL
parameter <- NULL
c_effect <- NULL
c_lower <- NULL
c_upper <- NULL
c_FDR <- NULL
v_lower <- NULL
v_upper <- NULL
v_FDR <- NULL
cell_type_label <- NULL
multipanel_theme <- NULL
.cell_group = attr(.data, ".cell_group")
# Check whether tests have been run
if(.data |> select(ends_with("FDR")) |> ncol() |> equals(0))
stop("sccomp says: to produce plots, you need to run the function sccomp_test() on your estimates.")
# Mean-variance association
.data %>%
# Filter where variance is inferred
filter(!is.na(v_effect)) %>%
# Add labels for significant cell groups
with_groups(
parameter,
~ .x %>%
arrange(c_FDR) %>%
mutate(cell_type_label = if_else(row_number() <= 3 & c_FDR < significance_threshold & parameter != "(Intercept)", !!.cell_group, ""))
) %>%
with_groups(
parameter,
~ .x %>%
arrange(v_FDR) %>%
mutate(cell_type_label = if_else((row_number() <= 3 & v_FDR < significance_threshold & parameter != "(Intercept)"), !!.cell_group, cell_type_label))
) %>%
{
.x = (.)
# Plot
ggplot(.x, aes(c_effect, v_effect)) +
# Add vertical and horizontal lines
geom_vline(xintercept = c(-test_composition_above_logit_fold_change, test_composition_above_logit_fold_change), colour = "grey", linetype = "dashed", linewidth = 0.3) +
geom_hline(yintercept = c(-test_composition_above_logit_fold_change, test_composition_above_logit_fold_change), colour = "grey", linetype = "dashed", linewidth = 0.3) +
# Add error bars
geom_errorbar(aes(xmin = `c_lower`, xmax = `c_upper`, color = `c_FDR` < significance_threshold, alpha = `c_FDR` < significance_threshold), linewidth = 0.2) +
geom_errorbar(aes(ymin = v_lower, ymax = v_upper, color = `v_FDR` < significance_threshold, alpha = `v_FDR` < significance_threshold), linewidth = 0.2) +
# Add points
geom_point(size = 0.2) +
# Add annotations
annotate("text", x = 0, y = 3.5, label = "Variable", size = 2) +
annotate("text", x = 5, y = 0, label = "Abundant", size = 2, angle = 270) +
# Add text labels for significant cell groups
geom_text_repel(aes(c_effect, -v_effect, label = cell_type_label), size = 2.5, data = .x %>% filter(cell_type_label != "")) +
# Set color and alpha scales
scale_color_manual(values = c("#D3D3D3", "#E41A1C")) +
scale_alpha_manual(values = c(0.4, 1)) +
# Facet by parameter
facet_wrap(~parameter, scales = "free") +
# Apply custom theme
multipanel_theme
}
}
#' Plot Boxplot of Cell-group Proportion
#'
#' This function creates a boxplot of cell-group proportions, optionally highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param data_proportion Data frame containing proportions of cell groups.
#' @param factor_of_interest A factor indicating the biological condition of interest.
#' @param .cell_group The cell group to be analysed.
#' @param .sample The sample identifier.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param my_theme A ggplot2 theme object to be applied to the plot.
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#'
#'
#' @return A ggplot object representing the boxplot.
#' @examples
#' # Example usage:
#' # plot_boxplot(.data, data_proportion, "condition", "cell_group", "sample", 0.025, theme_minimal())
plot_boxplot = function(
.data, data_proportion, factor_of_interest, .cell_group,
.sample, significance_threshold = 0.05, my_theme
){
# Define the variables as NULL to avoid CRAN NOTES
stats_name <- NULL
parameter <- NULL
stats_value <- NULL
count_data <- NULL
generated_proportions <- NULL
proportion <- NULL
name <- NULL
outlier <- NULL
# Function to calculate boxplot statistics
calc_boxplot_stat <- function(x) {
coef <- 1.5
n <- sum(!is.na(x))
# Calculate quantiles
stats <- quantile(x, probs = c(0.0, 0.25, 0.5, 0.75, 1.0))
names(stats) <- c("ymin", "lower", "middle", "upper", "ymax")
iqr <- diff(stats[c(2, 4)])
# Set whiskers
outliers <- x < (stats[2] - coef * iqr) | x > (stats[4] + coef * iqr)
if (any(outliers)) {
stats[c(1, 5)] <- range(c(stats[2:4], x[!outliers]), na.rm = TRUE)
}
return(stats)
}
# Function to remove leading zero from labels
dropLeadingZero <- function(l){ stringr::str_replace(l, '0(?=.)', '') }
# Define square root transformation and its inverse
S_sqrt <- function(x){sign(x)*sqrt(abs(x))}
IS_sqrt <- function(x){x^2*sign(x)}
S_sqrt_trans <- function() scales::trans_new("S_sqrt",S_sqrt,IS_sqrt)
.cell_group = enquo(.cell_group)
.sample = enquo(.sample)
# Prepare significance colors
significance_colors =
.data %>%
pivot_longer(
c(contains("c_"), contains("v_")),
names_pattern = "([cv])_([a-zA-Z0-9]+)",
names_to = c("which", "stats_name"),
values_to = "stats_value"
) %>%
filter(stats_name == "FDR") %>%
filter(parameter != "(Intercept)") %>%
filter(stats_value < significance_threshold) %>%
filter(`factor` == factor_of_interest)
if(nrow(significance_colors) > 0){
if(.data |> attr("contrasts") |> is.null())
significance_colors =
significance_colors %>%
unite("name", c(which, parameter), remove = FALSE) %>%
distinct() %>%
# Get clean parameter
mutate(!!as.symbol(factor_of_interest) := str_replace(parameter, sprintf("^%s", `factor`), "")) %>%
with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
else
significance_colors =
significance_colors |>
mutate(count_data = map(count_data, ~ .x |> select(all_of(factor_of_interest)) |> distinct())) |>
unnest(count_data) |>
# Filter relevant parameters
mutate( !!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest) ) ) |>
filter(str_detect(parameter, !!as.symbol(factor_of_interest) )) |>
# Rename
select(!!.cell_group, !!as.symbol(factor_of_interest), name = parameter) |>
# Merge contrasts
with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
}
my_boxplot = ggplot()
if("fit" %in% names(attributes(.data))){
simulated_proportion =
.data |>
sccomp_replicate(number_of_draws = 100) |>
left_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.sample, !!.cell_group))
my_boxplot = my_boxplot +
# Add boxplot for simulated proportions
stat_summary(
aes(!!as.symbol(factor_of_interest), (generated_proportions)),
fun.data = calc_boxplot_stat, geom="boxplot",
outlier.shape = NA, outlier.color = NA,outlier.size = 0,
fatten = 0.5, lwd=0.2,
data =
simulated_proportion %>%
inner_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.cell_group)),
color="blue"
)
}
if(nrow(significance_colors) == 0 ||
length(intersect(
significance_colors |> pull(!!as.symbol(factor_of_interest)),
data_proportion |> pull(!!as.symbol(factor_of_interest))
)) == 0){
my_boxplot=
my_boxplot +
# Add boxplot without significance colors
geom_boxplot(
aes(!!as.symbol(factor_of_interest), proportion, group=!!as.symbol(factor_of_interest), fill = NULL),
outlier.shape = NA, outlier.color = NA,outlier.size = 0,
data =
data_proportion |>
mutate(!!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest))) ,
fatten = 0.5,
lwd=0.5,
)
} else {
my_boxplot=
my_boxplot +
# Add boxplot with significance colors
geom_boxplot(
aes(!!as.symbol(factor_of_interest), proportion, group=!!as.symbol(factor_of_interest), fill = name),
outlier.shape = NA, outlier.color = NA,outlier.size = 0,
data =
data_proportion |>
mutate(!!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest))) %>%
left_join(significance_colors, by = c(quo_name(.cell_group), factor_of_interest)),
fatten = 0.5,
lwd=0.5,
)
}
my_boxplot +
# Add jittered points for individual data
geom_jitter(
aes(!!as.symbol(factor_of_interest), proportion, shape=outlier, color=outlier, group=!!as.symbol(factor_of_interest)),
data = data_proportion,
position=position_jitterdodge(jitter.height = 0, jitter.width = 0.2),
size = 0.5
) +
# Facet wrap by cell group
facet_wrap(
vars(!!.cell_group),
scales = "free_y",
nrow = 4
) +
scale_color_manual(values = c("black", "#e11f28")) +
scale_y_continuous(trans=S_sqrt_trans(), labels = dropLeadingZero) +
scale_fill_discrete(na.value = "white") +
xlab("Biological condition") +
ylab("Cell-group proportion") +
guides(color="none", alpha="none", size="none") +
labs(fill="Significant difference") +
ggtitle("Note: Be careful judging significance (or outliers) visually for lowly abundant cell groups. \nVisualising proportion hides the uncertainty characteristic of count data, that a count-based statistical model can estimate.") +
my_theme +
theme(axis.text.x = element_text(angle=20, hjust = 1), title = element_text(size = 3))
}
#' Plot Scatterplot of Cell-group Proportion
#'
#' This function creates a scatterplot of cell-group proportions, optionally highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param data_proportion Data frame containing proportions of cell groups.
#' @param factor_of_interest A character string naming the factor of interest (the biological condition).
#' @param .cell_group The column containing the cell-group identifier.
#' @param .sample The column containing the sample identifier.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param my_theme A ggplot2 theme object to be applied to the plot.
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#' @importFrom magrittr equals
#'
#'
#' @return A ggplot object representing the scatterplot.
#' @examples
#' # Example usage:
#' # plot_scatterplot(.data, data_proportion, "condition", "cell_group", "sample", 0.025, theme_minimal())
plot_scatterplot = function(
.data, data_proportion, factor_of_interest, .cell_group,
.sample, significance_threshold = 0.05, my_theme
){
# Define the variables as NULL to avoid CRAN NOTES
stats_name <- NULL
parameter <- NULL
stats_value <- NULL
count_data <- NULL
generated_proportions <- NULL
proportion <- NULL
name <- NULL
outlier <- NULL
# Function to remove leading zero from labels
dropLeadingZero <- function(l){ stringr::str_replace(l, '0(?=.)', '') }
# Define square root transformation and its inverse
S_sqrt <- function(x){sign(x)*sqrt(abs(x))}
IS_sqrt <- function(x){x^2*sign(x)}
S_sqrt_trans <- function() scales::trans_new("S_sqrt",S_sqrt,IS_sqrt)
.cell_group = enquo(.cell_group)
.sample = enquo(.sample)
# Prepare significance colors
significance_colors =
.data %>%
pivot_longer(
c(contains("c_"), contains("v_")),
names_pattern = "([cv])_([a-zA-Z0-9]+)",
names_to = c("which", "stats_name"),
values_to = "stats_value"
) %>%
filter(stats_name == "FDR") %>%
filter(parameter != "(Intercept)") %>%
filter(stats_value < significance_threshold) %>%
filter(`factor` == factor_of_interest)
if(nrow(significance_colors) > 0){
if(.data |> attr("contrasts") |> is.null())
significance_colors =
significance_colors %>%
unite("name", c(which, parameter), remove = FALSE) %>%
distinct() %>%
# Get clean parameter
mutate(!!as.symbol(factor_of_interest) := str_replace(parameter, sprintf("^%s", `factor`), "")) %>%
with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
else
significance_colors =
significance_colors |>
mutate(count_data = map(count_data, ~ .x |> select(all_of(factor_of_interest)) |> distinct())) |>
unnest(count_data) |>
# Filter relevant parameters
mutate( !!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest) ) ) |>
filter(str_detect(parameter, !!as.symbol(factor_of_interest) )) |>
# Rename
select(!!.cell_group, !!as.symbol(factor_of_interest), name = parameter) |>
# Merge contrasts
with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
}
my_scatterplot = ggplot()
if("fit" %in% names(attributes(.data))){
simulated_proportion =
.data |>
sccomp_replicate(number_of_draws = 1000) |>
left_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.sample, !!.cell_group))
my_scatterplot =
my_scatterplot +
# Add smoothed line for simulated proportions
geom_smooth(
aes(!!as.symbol(factor_of_interest), (generated_proportions)),
lwd=0.2,
data =
simulated_proportion %>%
inner_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.cell_group, !!.sample)) ,
color="blue", fill="blue",
span = 1
)
}
if(
nrow(significance_colors)==0 ||
significance_colors |>
pull(!!as.symbol(factor_of_interest)) |>
intersect(
data_proportion |>
pull(!!as.symbol(factor_of_interest))
) |>
length() |>
equals(0)
) {
my_scatterplot=
my_scatterplot +
# Add smoothed line without significance colors
geom_smooth(
aes(!!as.symbol(factor_of_interest), proportion, fill = NULL),
data =
data_proportion ,
lwd=0.5,
color = "black",
span = 1
)
} else {
my_scatterplot=
my_scatterplot +
# Add smoothed line with significance colors
geom_smooth(
aes(!!as.symbol(factor_of_interest), proportion, fill = name),
# The outlier.* and fatten arguments belong to geom_boxplot and are ignored
# (with a warning) by geom_smooth, so they are dropped here. Significance
# labels are joined in so that `fill = name` finds the column in the data
# (assuming at most one label per cell group).
data =
data_proportion |>
left_join(
significance_colors |> distinct(!!.cell_group, name),
by = quo_name(.cell_group)
),
lwd=0.5,
color = "black",
span = 1
)
}
my_scatterplot +
# Add jittered points for individual data
geom_point(
aes(!!as.symbol(factor_of_interest), proportion, shape=outlier, color=outlier),
data = data_proportion,
position=position_jitterdodge(jitter.height = 0, jitter.width = 0.2),
size = 0.5
) +
# Facet wrap by cell group
facet_wrap(
vars(!!.cell_group),
scales = "free_y",
nrow = 4
) +
scale_color_manual(values = c("black", "#e11f28")) +
scale_y_continuous(trans=S_sqrt_trans(), labels = dropLeadingZero) +
scale_fill_discrete(na.value = "white") +
xlab("Biological condition") +
ylab("Cell-group proportion") +
guides(color="none", alpha="none", size="none") +
labs(fill="Significant difference") +
ggtitle("Note: Be careful judging significance (or outliers) visually for lowly abundant cell groups. \nVisualising proportion hides the uncertainty characteristic of count data, that a count-based statistical model can estimate.") +
my_theme +
theme(axis.text.x = element_text(angle=20, hjust = 1), title = element_text(size = 3))
}
draws_to_statistics = function(draws, false_positive_rate, test_composition_above_logit_fold_change, .cell_group, prefix = ""){
# Define the variables as NULL to avoid CRAN NOTES
M <- NULL
parameter <- NULL
bigger_zero <- NULL
smaller_zero <- NULL
lower <- NULL
effect <- NULL
upper <- NULL
pH0 <- NULL
FDR <- NULL
n_eff <- NULL
R_k_hat <- NULL
.cell_group = enquo(.cell_group)
draws =
draws |>
with_groups(c(!!.cell_group, M, parameter), ~ .x |> summarise(
lower = quantile(.value, false_positive_rate/2),
effect = quantile(.value, 0.5),
upper = quantile(.value, 1-(false_positive_rate/2)),
bigger_zero = which(.value>test_composition_above_logit_fold_change) |> length(),
smaller_zero = which(.value< -test_composition_above_logit_fold_change) |> length(),
R_k_hat = unique(R_k_hat),
n_eff = unique(n_eff),
n=n()
)) |>
# Calculate probability non 0
mutate(pH0 = (1 - (pmax(bigger_zero, smaller_zero) / n))) |>
with_groups(parameter, ~ mutate(.x, FDR = get_FDR(pH0))) |>
select(!!.cell_group, M, parameter, lower, effect, upper, pH0, FDR, any_of(c("n_eff", "R_k_hat"))) |>
suppressWarnings()
# Setting up names separately because |> is not flexible enough
draws |>
setNames(c(colnames(draws)[1:3], sprintf("%s%s", prefix, colnames(draws)[4:ncol(draws)])))
}
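# Hedged usage sketch (hypothetical `draws` tibble; not part of the package API):
# `draws` is expected to hold one row per posterior draw, with the cell-group
# column, M, parameter, .value, n_eff and R_k_hat.
# draws_to_statistics(
#   draws,
#   false_positive_rate = 0.05,
#   test_composition_above_logit_fold_change = 0.2,
#   .cell_group = cell_group,
#   prefix = "c_"
# )
# # Returns one row per cell group and parameter, with c_lower, c_effect,
# # c_upper, c_pH0 and c_FDR columns.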
enquos_from_list_of_symbols <- function(...) {
enquos(...)
}
contrasts_to_enquos = function(contrasts){
# Define the variables as NULL to avoid CRAN NOTES
. <- NULL
contrasts |> enquo() |> quo_names() |> syms() %>% do.call(enquos_from_list_of_symbols, .)
}
#' Mutate Data Frame Based on Expression List
#'
#' @description
#' `mutate_from_expr_list` takes a data frame and a list of formula expressions,
#' and mutates the data frame based on these expressions. It allows for ignoring
#' errors during the mutation process.
#'
#' @param x A data frame to be mutated.
#' @param formula_expr A named list of formula expressions used for mutation.
#' @param ignore_errors Logical flag indicating whether to ignore errors during mutation.
#'
#' @return A mutated data frame with added or modified columns based on `formula_expr`.
#'
#' @details
#' The function performs various checks and transformations on the formula expressions,
#' ensuring that the specified transformations are valid and can be applied to the data frame.
#' It supports advanced features like handling special characters in column names and intelligent
#' parsing of formulas.
#'
#' @importFrom purrr map2_dfc
#' @importFrom tibble add_column
#' @importFrom tidyselect last_col
#' @importFrom dplyr mutate
#' @importFrom stringr str_subset
#'
#' @noRd
#'
mutate_from_expr_list = function(x, formula_expr, ignore_errors = TRUE){
if(formula_expr |> names() |> is.null())
names(formula_expr) = formula_expr
# Check if all elements of contrasts are in the parameter
parameter_names = x |> colnames()
# Creating a named vector where the names are the strings to be replaced
# and the values are empty strings
# Using str_replace_all to replace each instance of the strings in A with an empty string in B
contrasts_elements <-
formula_expr |>
# Remove fractions
str_remove_all_ignoring_if_inside_backquotes("[0-9]+/[0-9]+ ?\\*") |>
# Remove decimals
str_remove_all_ignoring_if_inside_backquotes("[-+]?[0-9]+\\.[0-9]+ ?\\*") |>
str_split_ignoring_if_inside_backquotes("\\+|-|\\*") |>
unlist() |>
str_remove_all_ignoring_if_inside_backquotes("[\\(\\) ]")
# Check whether backquotes are missing where they are required
require_back_quotes = !contrasts_elements |> str_remove_all("`") |> contains_only_valid_chars_for_column()
has_left_back_quotes = contrasts_elements |> str_detect("^`")
has_right_back_quotes = contrasts_elements |> str_detect("`$")
if_true_not_good = require_back_quotes & !(has_left_back_quotes & has_right_back_quotes)
if(any(if_true_not_good))
warning(sprintf("sccomp says: for columns which have special characters e.g. %s, you need to use surrounding backquotes ``.", paste(contrasts_elements[if_true_not_good], collapse = ", ")))
# Check if columns exist
contrasts_not_in_the_model =
contrasts_elements |>
str_remove_all("`") |>
setdiff(parameter_names)
contrasts_not_in_the_model = contrasts_not_in_the_model[contrasts_not_in_the_model!=""]
if(length(contrasts_not_in_the_model) > 0 && !ignore_errors)
warning(sprintf("sccomp says: These components of your contrasts are not present in the model as parameters: %s. Factors including special characters, e.g. \"(Intercept)\" require backquotes e.g. \"`(Intercept)`\" ", paste(contrasts_not_in_the_model, collapse = ", ")))
# Calculate
if(ignore_errors) my_mutate = mutate_ignore_error
else my_mutate = mutate
map2_dfc(
formula_expr,
names(formula_expr),
~ x |>
my_mutate(!!.y := eval(rlang::parse_expr(.x))) |>
# mutate(!!column_name := eval(rlang::parse_expr(.x))) |>
select(any_of(.y))
) |>
# This could be dropped to return only the contrast columns
add_column(x |> select(-any_of(names(formula_expr))), .before = 1)
}
mutate_ignore_error = function(x, ...){
tryCatch(
{ x |> mutate(...) },
error=function(cond) { x }
)
}
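# Example (commented so it does not run at source time; hypothetical data):
# a failing mutation returns the input unchanged instead of aborting.
# tibble::tibble(a = 1:3) |> mutate_ignore_error(b = a * 2)           # adds `b`
# tibble::tibble(a = 1:3) |> mutate_ignore_error(b = no_such_fun(a))  # returns input as-is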
simulate_multinomial_logit_linear = function(model_input, sd = 0.51){
mu = model_input$X %*% model_input$beta
proportions =
rnorm(length(mu), mu, sd) %>%
matrix(nrow = nrow(model_input$X)) %>%
boot::inv.logit() %>%
# Normalise each sample (row) so proportions sum to one
apply(1, function(x) x/sum(x)) %>%
t()
rownames(proportions) = rownames(model_input$X)
colnames(proportions) = colnames(model_input$beta)
proportions
}
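# Hedged usage sketch (hypothetical `model_input`; no structure beyond X and
# beta is assumed): X is samples x covariates, beta is covariates x cell groups;
# each row of the returned matrix sums to one.
# model_input = list(
#   X    = matrix(c(1, 1, 1, 0, 0, 1), nrow = 3,
#                 dimnames = list(paste0("sample_", 1:3), c("(Intercept)", "treatment"))),
#   beta = matrix(rnorm(6), nrow = 2, dimnames = list(NULL, c("B", "T", "NK")))
# )
# simulate_multinomial_logit_linear(model_input)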
compress_zero_one = function(y){
# https://stats.stackexchange.com/questions/48028/beta-regression-of-proportion-data-including-1-and-0
n = length(y)
(y * (n-1) + 0.5) / n
}
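# Example (commented): for n = 3 observations the boundary values 0 and 1 are
# pulled inside the open interval (0, 1), as required by e.g. beta regression.
# compress_zero_one(c(0, 0.5, 1))
# #> [1] 0.1666667 0.5000000 0.8333333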
# this can be helpful if we want to draw PCA with uncertainty
get_abundance_contrast_draws = function(.data, contrasts){
# Define the variables as NULL to avoid CRAN NOTES
X <- NULL
.value <- NULL
X_random_effect <- NULL
.variable <- NULL
y <- NULL
M <- NULL
khat <- NULL
parameter <- NULL
n_eff <- NULL
R_k_hat <- NULL
.cell_group = .data |> attr(".cell_group")
# Beta
beta_factor_of_interest = .data |> attr("model_input") %$% X |> colnames()
beta =
.data |>
attr("fit") %>%
draws_to_tibble_x_y("beta", "C", "M") |>
pivot_wider(names_from = C, values_from = .value) %>%
setNames(colnames(.)[1:5] |> c(beta_factor_of_interest))
# Abundance
draws = select(beta, -.variable)
# Random effect
if(.data |> attr("model_input") %$% n_random_eff > 0){
beta_random_effect_factor_of_interest = .data |> attr("model_input") %$% X_random_effect |> colnames()
beta_random_effect =
.data |>
attr("fit") %>%
draws_to_tibble_x_y("random_effect", "C", "M")
# Add last component
beta_random_effect =
beta_random_effect |>
bind_rows(
beta_random_effect |>
with_groups(c(C, .chain, .iteration, .draw, .variable ), ~ .x |> summarise(.value = sum(.value))) |>
mutate(.value = -.value, M = beta_random_effect |> pull(M) |> max() + 1)
)
# Reshape
beta_random_effect =
beta_random_effect |>
pivot_wider(names_from = C, values_from = .value) %>%
setNames(colnames(.)[1:5] |> c(beta_random_effect_factor_of_interest))
draws = draws |>
left_join(select(beta_random_effect, -.variable),
by = c("M", ".chain", ".iteration", ".draw")
)
} else {
beta_random_effect_factor_of_interest = ""
}
# Second random effect. IN THE FUTURE THIS WILL BE VECTORISED TO ARBITRARY GROUPINGS
if(.data |> attr("model_input") %$% n_random_eff > 1){
beta_random_effect_factor_of_interest_2 = .data |> attr("model_input") %$% X_random_effect_2 |> colnames()
beta_random_effect_2 =
.data |>
attr("fit") %>%
draws_to_tibble_x_y("random_effect_2", "C", "M")
# Add last component
beta_random_effect_2 =
beta_random_effect_2 |>
bind_rows(
beta_random_effect_2 |>
with_groups(c(C, .chain, .iteration, .draw, .variable ), ~ .x |> summarise(.value = sum(.value))) |>
mutate(.value = -.value, M = beta_random_effect_2 |> pull(M) |> max() + 1)
)
# Reshape
beta_random_effect_2 =
beta_random_effect_2 |>
pivot_wider(names_from = C, values_from = .value) %>%
setNames(colnames(.)[1:5] |> c(beta_random_effect_factor_of_interest_2))
draws = draws |>
left_join(select(beta_random_effect_2, -.variable),
by = c("M", ".chain", ".iteration", ".draw")
)
} else {
beta_random_effect_factor_of_interest_2 = ""
}
# If I have contrasts, calculate them
if(!is.null(contrasts))
draws =
draws |>
mutate_from_expr_list(contrasts, ignore_errors = FALSE) |>
select(- any_of(c(beta_factor_of_interest, beta_random_effect_factor_of_interest) |> setdiff(contrasts)) )
# Add cell name
draws = draws |>
left_join(
.data |>
attr("model_input") %$%
y %>%
colnames() |>
enframe(name = "M", value = quo_name(.cell_group)),
by = "M"
) %>%
select(!!.cell_group, everything())
# If no parameters of interest remain, return only the cell-group identifiers
if(ncol(draws)==5) return(draws |> distinct(M, !!.cell_group))
# Get convergence
convergence_df =
.data |>
attr("fit") |>
summary_to_tibble("beta", "C", "M") |>
# Add cell name
left_join(
.data |>
attr("model_input") %$%
y %>%
colnames() |>
enframe(name = "M", value = quo_name(.cell_group)),
by = "M"
) |>
# factor names
left_join(
beta_factor_of_interest |>
enframe(name = "C", value = "parameter"),
by = "C"
)
if ("Rhat" %in% colnames(convergence_df)) {
convergence_df <- rename(convergence_df, R_k_hat = Rhat)
} else if ("khat" %in% colnames(convergence_df)) {
convergence_df <- rename(convergence_df, R_k_hat = khat)
}
convergence_df =
convergence_df |>
select(!!.cell_group, parameter, any_of(c("n_eff", "R_k_hat"))) |>
suppressWarnings()
draws |>
pivot_longer(-c(1:5), names_to = "parameter", values_to = ".value") |>
# Attach convergence diagnostics (these only match plain parameters, not contrasts)
left_join(convergence_df, by = c(quo_name(.cell_group), "parameter")) |>
# Reorder because pivot_longer does not preserve the desired column order
mutate(parameter = parameter |> fct_relevel(colnames(draws)[-c(1:5)])) |>
arrange(parameter)
}
#' @importFrom forcats fct_relevel
#' @noRd
get_variability_contrast_draws = function(.data, contrasts){
# Define the variables as NULL to avoid CRAN NOTES
XA <- NULL
.value <- NULL
y <- NULL
M <- NULL
khat <- NULL
parameter <- NULL
n_eff <- NULL
R_k_hat <- NULL
.cell_group = .data |> attr(".cell_group")
variability_factor_of_interest = .data |> attr("model_input") %$% XA |> colnames()
draws =
.data |>
attr("fit") %>%
draws_to_tibble_x_y("alpha_normalised", "C", "M") |>
# We want variability, not concentration
mutate(.value = -.value) |>
pivot_wider(names_from = C, values_from = .value) %>%
setNames(colnames(.)[1:5] |> c(variability_factor_of_interest)) |>
select( -.variable)
# If I have contrasts, calculate them
if (!is.null(contrasts))
draws <- mutate_from_expr_list(draws, contrasts, ignore_errors = TRUE)
draws = draws |>
# Add cell name
left_join(
.data |> attr("model_input") %$%
y %>%
colnames() |>
enframe(name = "M", value = quo_name(.cell_group)),
by = "M"
) %>%
select(!!.cell_group, everything())
# If no parameters of interest remain, return only the cell-group identifiers
if(ncol(draws)==5) return(draws |> distinct(M, !!.cell_group))
# Get convergence
convergence_df =
.data |>
attr("fit") |>
summary_to_tibble("alpha_normalised", "C", "M") |>
# Add cell name
left_join(
.data |>
attr("model_input") %$%
y %>%
colnames() |>
enframe(name = "M", value = quo_name(.cell_group)),
by = "M"
) |>
# factor names
left_join(
variability_factor_of_interest |>
enframe(name = "C", value = "parameter"),
by = "C"
)
if ("Rhat" %in% colnames(convergence_df)) {
convergence_df <- rename(convergence_df, R_k_hat = Rhat)
} else if ("khat" %in% colnames(convergence_df)) {
convergence_df <- rename(convergence_df, R_k_hat = khat)
}
convergence_df =
convergence_df |>
select(!!.cell_group, parameter, any_of(c("n_eff", "R_k_hat"))) |>
suppressWarnings()
draws |>
pivot_longer(-c(1:5), names_to = "parameter", values_to = ".value") |>
# Attach convergence diagnostics (these only match plain parameters, not contrasts)
left_join(convergence_df, by = c(quo_name(.cell_group), "parameter")) |>
# Reorder because pivot_longer does not preserve the desired column order
mutate(parameter = parameter |> fct_relevel(colnames(draws)[-c(1:5)])) |>
arrange(parameter)
}
#' @importFrom tibble deframe
#'
#' @noRd
replicate_data = function(.data,
formula_composition = NULL,
formula_variability = NULL,
new_data = NULL,
number_of_draws = 1,
mcmc_seed = sample(1e5, 1),
cores = detectCores()){
# Select model based on noise model
noise_model = attr(.data, "noise_model")
model_input = attr(.data, "model_input")
.sample = attr(.data, ".sample")
.cell_group = attr(.data, ".cell_group")
# Composition
if(is.null(formula_composition)) formula_composition = .data |> attr("formula_composition")
# New data
if(new_data |> is.null())
new_data =
.data |>
select(count_data) |>
unnest(count_data) |>
distinct()
# If seurat
else if(new_data |> is("Seurat")) new_data = new_data[[]]
# Just subset
new_data = new_data |> .subset(!!.sample)
# Check if the input new data is not suitable
if(!parse_formula(formula_composition) %in% colnames(new_data) |> all())
stop(sprintf("sccomp says: your `new_data` might be malformed. It might have covariate columns with multiple values for some elements of the \"%s\" column. As a generic example, a sample identifier (\"Sample_123\") might be associated with multiple treatment or age values.", quo_name(.sample)))
# Match factors with old data
nrow_new_data = nrow(new_data)
new_exposure = new_data |>
nest(data = -!!.sample) |>
mutate(exposure = map_dbl(
data,
~{
if ("count" %in% colnames(.x)) sum(.x$count)
else 5000
})) |>
select(!!.sample, exposure) |>
deframe() |>
as.array()
# Update data, merge with old data because
# I need the same ordering of the design matrix
new_data =
# Old data
.data |>
select(count_data) |>
unnest(count_data) |>
select(-count) |>
select(new_data |> as_tibble() |> colnames() |> any_of()) |>
distinct() |>
# Change sample names to make unique
mutate(dummy = "OLD") |>
tidyr::unite(!!.sample, c(!!.sample, dummy), sep="___") |>
# New data
bind_rows(
new_data |> as_tibble()
)
new_X =
new_data |>
get_design_matrix(
# Drop random intercept
formula_composition |>
as.character() |>
str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
paste(collapse="") |>
as.formula(),
!!.sample
) |>
tail(nrow_new_data) %>%
# Remove columns that are not in the original design matrix
.[,colnames(.) %in% colnames(model_input$X), drop=FALSE]
X_which =
colnames(new_X) |>
match(
model_input$X %>%
colnames()
) |>
na.omit() |>
as.array()
# Variability
if(is.null(formula_variability)) formula_variability = .data |> attr("formula_variability")
new_Xa =
new_data |>
get_design_matrix(
# Drop random intercept
formula_variability |>
as.character() |>
str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
paste(collapse="") |>
as.formula(),
!!.sample
) |>
tail(nrow_new_data) %>%
# Remove columns that are not in the original design matrix
.[,colnames(.) %in% colnames(model_input$Xa), drop=FALSE]
XA_which =
colnames(new_Xa) |>
match(
model_input %$%
Xa %>%
colnames()
) |>
na.omit() |>
as.array()
# If I want to replicate data with intercept and I don't have intercept in my fit
create_intercept =
model_input %$% intercept_in_design |> not() &
"(Intercept)" %in% colnames(new_X)
if(create_intercept) warning("sccomp says: your estimated model is intercept-free, while the replicated data you requested includes an intercept term. The intercept estimate will be calculated by averaging the estimates of the first factor in your formula ~ 0 + <factor>. If this warning is unexpected, please reconsider the formula you pass to replicate_data()")
# Original grouping
original_grouping_names = .data |> attr("formula_composition") |> formula_to_random_effect_formulae() |> pull(grouping)
# Random intercept
random_effect_elements = parse_formula_random_effect(formula_composition)
random_effect_grouping =
new_data %>%
get_random_effect_design2(
!!.sample,
formula_composition
)
if((random_effect_grouping$grouping %in% original_grouping_names[1]) |> any()) {
# HAVE TO DEBUG
new_X_random_effect =
random_effect_grouping |>
filter(grouping==original_grouping_names[1]) |>
mutate(design_matrix = map(
design,
~ ..1 |>
select(!!.sample, group___label, value) |>
pivot_wider(names_from = group___label, values_from = value) |>
mutate(across(everything(), ~ .x |> replace_na(0)))
)) |>
# Merge
pull(design_matrix) |>
_[[1]] |>
as_matrix(rownames = quo_name(.sample)) |>
tail(nrow_new_data)
# I HAVE TO KEEP GROUP NAME IN COLUMN NAME
X_random_effect_which =
colnames(new_X_random_effect) |>
match(
model_input %$%
X_random_effect %>%
colnames()
) |>
as.array()
# Check if I have column in the new design that are not in the old one
missing_columns = new_X_random_effect |> colnames() |> setdiff(colnames(model_input$X_random_effect))
if(missing_columns |> length() > 0)
stop(glue("sccomp says: the columns in the design matrix {paste(missing_columns, collapse= ' ,')} are missing from the design matrix of the estimate-input object. Please make sure your new model is a sub-model of your estimated one."))
}
else{
X_random_effect_which = array()[0]
new_X_random_effect = matrix(rep(0, nrow_new_data))[,0, drop=FALSE]
}
if((random_effect_grouping$grouping %in% original_grouping_names[2]) |> any()){
# HAVE TO DEBUG
new_X_random_effect_2 =
random_effect_grouping |>
filter(grouping==original_grouping_names[2]) |>
mutate(design_matrix = map(
design,
~ ..1 |>
select(!!.sample, group___label, value) |>
pivot_wider(names_from = group___label, values_from = value) |>
mutate(across(everything(), ~ .x |> replace_na(0)))
)) |>
# Merge
pull(design_matrix) |>
_[[1]] |>
as_matrix(rownames = quo_name(.sample)) |>
tail(nrow_new_data)
# DUPLICATE
X_random_effect_which_2 =
colnames(new_X_random_effect_2) |>
match(
model_input %$%
X_random_effect_2 %>%
colnames()
) |>
as.array()
}
else{
X_random_effect_which_2 = array()[0]
new_X_random_effect_2 = matrix(rep(0, nrow_new_data))[,0, drop=FALSE]
}
# New X
model_input$X_original = model_input$X
model_input$X = new_X
model_input$Xa = new_Xa
model_input$N_original = model_input$N
model_input$N = nrow_new_data
model_input$exposure = new_exposure
model_input$X_random_effect = new_X_random_effect
model_input$X_random_effect_2 = new_X_random_effect_2
model_input$ncol_X_random_eff_new = ncol(new_X_random_effect) |> c(ncol(new_X_random_effect_2))
number_of_draws_in_the_fit = attr(.data, "fit") |> get_output_samples()
# To avoid error in case of a NULL posterior sample
number_of_draws = min(number_of_draws, number_of_draws_in_the_fit)
# Load model
mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores)
# Generate quantities
mod_rng |> sample_safe(
generate_quantities_fx,
attr(.data, "fit")$draws(format = "matrix")[
sample(seq_len(number_of_draws_in_the_fit), size=number_of_draws),, drop=FALSE
],
data = model_input |> c(list(
# Add subset of coefficients
length_X_which = length(X_which),
length_XA_which = length(XA_which),
X_which = X_which,
XA_which = XA_which,
# Random intercept
X_random_effect_which = X_random_effect_which,
X_random_effect_which_2 = X_random_effect_which_2,
length_X_random_effect_which = length(X_random_effect_which) |> c(length(X_random_effect_which_2)),
# Should I create intercept for generate quantities
create_intercept = create_intercept
)),
seed = mcmc_seed,
threads_per_chain = 1
)
}
get_model_from_data = function(file_compiled_model, model_code){
if(file.exists(file_compiled_model))
readRDS(file_compiled_model)
else {
model_generate = stan_model(model_code = model_code)
model_generate %>% saveRDS(file_compiled_model)
model_generate
}
}
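# Example (commented; hypothetical path and model): compile once with rstan,
# then reuse the cached binary from disk on subsequent calls.
# model = get_model_from_data(
#   "~/.sccomp_models/minimal.rds",
#   "parameters { real mu; } model { mu ~ normal(0, 1); }"
# )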
add_formula_columns = function(.data, .original_data, .sample, formula_composition){
.sample = enquo(.sample)
formula_elements = parse_formula(formula_composition)
# If no formula return the input
if(length(formula_elements) == 0) return(.data)
# Get random intercept
.grouping_for_random_effect = parse_formula_random_effect(formula_composition) |> pull(grouping) |> unique()
data_frame_formula =
.original_data %>%
as_tibble() |>
select( !!.sample, formula_elements, any_of(.grouping_for_random_effect) ) %>%
distinct()
.data |>
left_join(data_frame_formula, by = quo_name(.sample) )
}
#' chatGPT - Remove Specified Regex Pattern from Each String in a Vector
#'
#' This function takes a vector of strings and a regular expression pattern.
#' It removes occurrences of the pattern from each string, except where the pattern
#' is found inside backticks. The function returns a vector of cleaned strings.
#'
#' @param text_vector A character vector with the strings to be processed.
#' @param regex A character string containing a regular expression pattern to be removed
#' from the text.
#'
#' @return A character vector with the regex pattern removed from each string.
#' Occurrences of the pattern inside backticks are not removed.
#'
#' @examples
#' texts <- c("A string with (some) parentheses and `a (parenthesis) inside` backticks",
#' "Another string with (extra) parentheses")
#' cleaned_texts <- str_remove_all_ignoring_if_inside_backquotes(texts, "\\(")
#' print(cleaned_texts)
#'
#' @noRd
str_remove_all_ignoring_if_inside_backquotes <- function(text_vector, regex) {
# Nested function to handle regex removal for a single string
remove_regex_chars <- function(text, regex) {
inside_backticks <- FALSE
result <- ""
skip <- 0
chars <- strsplit(text, "")[[1]]
for (i in seq_along(chars)) {
if (skip > 0) {
skip <- skip - 1
next
}
char <- chars[i]
if (char == "`") {
inside_backticks <- !inside_backticks
result <- paste0(result, char)
} else if (!inside_backticks) {
# Check the remaining text against the regex
remaining_text <- paste(chars[i:length(chars)], collapse = "")
match <- regexpr(regex, remaining_text)
if (attr(match, "match.length") > 0 && match[1] == 1) {
# Skip the length of the matched text
skip <- attr(match, "match.length") - 1
next
} else {
result <- paste0(result, char)
}
} else {
result <- paste0(result, char)
}
}
return(result)
}
# Apply the function to each element in the vector
sapply(text_vector, remove_regex_chars, regex)
}
#' chatGPT - Split Each String in a Vector by a Specified Regex Pattern
#'
#' This function takes a vector of strings and a regular expression pattern. It splits
#' each string based on the pattern, except where the pattern is found inside backticks.
#' The function returns a list, with each element being a vector of the split segments
#' of the corresponding input string.
#'
#' @param text_vector A character vector with the strings to be processed.
#' @param regex A character string containing a regular expression pattern used for splitting
#' the text.
#'
#' @return A list of character vectors. Each list element corresponds to an input string
#' from `text_vector`, split according to `regex`, excluding occurrences inside backticks.
#'
#' @examples
#' texts <- c("A string with, some, commas, and `a, comma, inside` backticks",
#' "Another string, with, commas")
#' split_texts <- str_split_ignoring_if_inside_backquotes(texts, ",")
#' print(split_texts)
#'
#' @noRd
str_split_ignoring_if_inside_backquotes <- function(text_vector, regex) {
# Nested function to handle regex split for a single string
split_regex_chars <- function(text, regex) {
inside_backticks <- FALSE
result <- c()
current_segment <- ""
skip <- 0
chars <- strsplit(text, "")[[1]]
for (i in seq_along(chars)) {
if (skip > 0) {
skip <- skip - 1
next
}
char <- chars[i]
if (char == "`") {
inside_backticks <- !inside_backticks
current_segment <- paste0(current_segment, char)
} else if (!inside_backticks) {
# Check the remaining text against the regex
remaining_text <- paste(chars[i:length(chars)], collapse = "")
match <- regexpr(regex, remaining_text)
if (attr(match, "match.length") > 0 && match[1] == 1) {
# Add current segment to result and start a new segment
result <- c(result, current_segment)
current_segment <- ""
# Skip the rest of the matched separator. Reassigning `i` inside a
# for loop has no effect in R, so a skip counter is used instead.
skip <- attr(match, "match.length") - 1
next
} else {
current_segment <- paste0(current_segment, char)
}
} else {
current_segment <- paste0(current_segment, char)
}
}
# Add the last segment to the result
result <- c(result, current_segment)
return(result)
}
# Apply the function to each element in the vector
lapply(text_vector, split_regex_chars, regex)
}
#' chatGPT - Check for Valid Column Names in Tidyverse Context
#'
#' This function checks if each given column name in a vector contains only valid characters
#' (letters, numbers, periods, and underscores) and does not start with a digit
#' or an underscore, which are the conditions for a valid column name in `tidyverse`.
#'
#' @param column_names A character vector representing the column names to be checked.
#'
#' @return A logical vector: `TRUE` for each column name that contains only valid characters
#' and does not start with a digit or an underscore; `FALSE` otherwise.
#'
#' @examples
#' contains_only_valid_chars_for_column(c("valid_column", "invalid column", "valid123",
#' "123startWithNumber", "_startWithUnderscore"))
#'
#' @noRd
contains_only_valid_chars_for_column <- function(column_names) {
# Function to check a single column name
check_validity <- function(column_name) {
# Regex pattern for valid characters (letters, numbers, periods, underscores)
valid_char_pattern <- "[A-Za-z0-9._]"
# Check if all characters in the string match the valid pattern
all_chars_valid <- stringr::str_detect(column_name, paste0("^", valid_char_pattern, "+$"))
# Check for leading digits or underscores
starts_with_digit_or_underscore <- stringr::str_detect(column_name, "^[0-9_]")
return(all_chars_valid && !starts_with_digit_or_underscore)
}
# Apply the check to each element of the vector
sapply(column_names, check_validity)
}
#' chatGPT - Intelligently Remove Surrounding Brackets from Each String in a Vector
#'
#' This function processes each string in a vector and removes surrounding brackets if the content
#' within the brackets includes any of '+', '-', or '*', and if the brackets are not
#' within backticks. This is particularly useful for handling formula-like strings.
#'
#' @param text A character vector with strings from which the brackets will be removed based on
#' specific conditions.
#'
#' @return A character vector with the specified brackets removed from each string.
#'
#' @examples
#' str_remove_brackets_from_formula_intelligently(c("This is a test (with + brackets)", "`a (test) inside` backticks", "(another test)"))
#'
#' @noRd
str_remove_brackets_from_formula_intelligently <- function(text) {
# Function to remove brackets from a single string
remove_brackets_single <- function(s) {
inside_backticks <- FALSE
bracket_depth <- 0
valid_bracket_content <- FALSE
result <- ""
bracket_content <- ""
chars <- strsplit(s, "")[[1]]
for (i in seq_along(chars)) {
char <- chars[i]
if (char == "`") {
inside_backticks <- !inside_backticks
}
if (!inside_backticks) {
if (char == "(") {
bracket_depth <- bracket_depth + 1
if (bracket_depth > 1) {
bracket_content <- paste0(bracket_content, char)
}
next
} else if (char == ")") {
bracket_depth <- bracket_depth - 1
if (bracket_depth == 0) {
if (grepl("[\\+\\-\\*]", bracket_content)) {
result <- paste0(result, bracket_content)
} else {
result <- paste0(result, "(", bracket_content, ")")
}
bracket_content <- ""
next
}
}
if (bracket_depth >= 1) {
bracket_content <- paste0(bracket_content, char)
} else {
result <- paste0(result, char)
}
} else {
result <- paste0(result, char)
}
}
return(result)
}
# Apply the function to each element in the vector
sapply(text, remove_brackets_single)
}
# Negation
not = function(is){ !is }
#' Convert array of quosure (e.g. c(col_a, col_b)) into character vector
#'
#' @keywords internal
#' @noRd
#'
#' @importFrom rlang quo_name
#' @importFrom rlang quo_squash
#'
#' @param v A array of quosures (e.g. c(col_a, col_b))
#'
#' @return A character vector
quo_names <- function(v) {
v = quo_name(quo_squash(v))
gsub('^c\\(|`|\\)$', '', v) |>
strsplit(', ') |>
unlist()
}
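# Example (commented): both a single column and a c() of columns collapse to a
# character vector.
# quo_names(rlang::quo(col_a))            #> "col_a"
# quo_names(rlang::quo(c(col_a, col_b)))  #> c("col_a", "col_b")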
#' Add class to object
#'
#' @keywords internal
#' @noRd
#'
#' @param var A tibble
#' @param name A character name of the attribute
#'
#' @return A tibble with an additional attribute
add_class = function(var, name) {
if(!name %in% class(var)) class(var) <- append(class(var),name, after = 0)
var
}
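# Example (commented): the class is prepended so S3 dispatch hits sccomp
# methods before the tibble ones.
# x = add_class(tibble::tibble(a = 1), "sccomp_tbl")
# class(x)
# #> [1] "sccomp_tbl" "tbl_df" "tbl" "data.frame"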
#' Get Output Samples from a Stan Fit Object
#'
#' This function retrieves the number of output samples from a Stan fit object,
#' supporting different inference methods (MCMC, variational, and Pathfinder) based on the metadata available within the object.
#'
#' @param fit A Stan fit object (e.g. a `cmdstanr` fit), which is the result of fitting a model via Stan.
#' @return The number of output samples used in the Stan model,
#' taken from whichever metadata field matches the inference method used.
#' @examples
#' # Assuming 'fit' is a Stan fit object obtained from running a Stan model
#' print("samples_count = get_output_samples(fit)")
#'
#' @export
#'
get_output_samples = function(fit){
# Check if the output_samples field is present in the metadata of the fit object
# This field is set when the model is fit using variational inference
if(!is.null(fit$metadata()$output_samples)) {
# Return the output_samples from the metadata
fit$metadata()$output_samples
}
# If the output_samples field is not present, check for iter_sampling
# This field is set when the model is fit using MCMC (Markov chain Monte Carlo)
else if(!is.null(fit$metadata()$iter_sampling)) {
# Return the iter_sampling from the metadata
fit$metadata()$iter_sampling
}
# Otherwise fall back to the Pathfinder PSIS draw count
else
fit$metadata()$num_psis_draws
}
#' Load, Compile, and Cache a Stan Model
#'
#' This function attempts to load a precompiled Stan model using the `instantiate` package.
#' If the model is not found in the cache or force recompilation is requested, it will locate
#' the Stan model file within the `sccomp` package, compile it using `cmdstanr`, and save the
#' compiled model to the cache directory for future use.
#'
#' @param name A character string representing the name of the Stan model (without the `.stan` extension).
#' @param cache_dir A character string representing the path to the cache directory where compiled models are saved.
#' Defaults to `sccomp_stan_models_cache_dir`.
#' @param force A logical value. If `TRUE`, the model will be recompiled even if it exists in the cache.
#' Defaults to `FALSE`.
#' @param threads An integer specifying the number of threads to use for compilation.
#' Defaults to `1`.
#'
#' @return A compiled Stan model object from `cmdstanr`.
#'
#' @importFrom instantiate stan_package_model
#' @importFrom instantiate stan_package_compile
#'
#' @noRd
#'
#' @examples
#' \donttest{
#' model <- load_model("glm_multi_beta_binomial_", "~/cache", force = FALSE, threads = 1)
#' }
load_model <- function(name, cache_dir = sccomp_stan_models_cache_dir, force=FALSE, threads = 1) {
# Try to load the model from cache
# RDS compiled model
cache_dir |> dir.create(showWarnings = FALSE, recursive = TRUE)
cache_file <- file.path(cache_dir, paste0(name, ".rds"))
# .STAN raw model
stan_model_path <- system.file("stan", paste0(name, ".stan"), package = "sccomp")
if (file.exists(cache_file) & !force) {
message("Loading model from cache...")
return(readRDS(cache_file))
}
# Model not found in cache: compile it from the Stan file shipped with the package
message("Precompiled model not found. Compiling the model...")
# Compile the Stan model using cmdstanr with threading support enabled
instantiate::stan_package_compile(
stan_model_path,
cpp_options = list(stan_threads = TRUE),
force_recompile = TRUE,
threads = threads,
dir = system.file("stan", package = "sccomp")
)
mod = instantiate::stan_package_model(
name = name,
package = "sccomp",
compile = TRUE,
cpp_options = list(stan_threads = TRUE)
) |> suppressWarnings()
# Save the compiled model object to cache
saveRDS(mod, file = cache_file)
message("Model compiled and saved to cache successfully.")
return(mod)
}
#' Check and Install cmdstanr and CmdStan
#'
#' This function checks if the `cmdstanr` package and CmdStan are installed.
#' If they are not installed, it installs them automatically in non-interactive sessions
#' or asks for permission to install them in interactive sessions.
#'
#' @importFrom instantiate stan_cmdstan_exists
#' @importFrom utils install.packages
#' @importFrom utils menu
#' @return NULL
#'
#' @noRd
check_and_install_cmdstanr <- function() {
# Check if cmdstanr is installed
if (!requireNamespace("cmdstanr", quietly = TRUE)) {
clear_stan_model_cache()
stop(
"cmdstanr is required to proceed.\n\n",
"Step 1: Please install the R package 'cmdstanr' using the following command:\n",
"install.packages(\"cmdstanr\", repos = c(\"https://stan-dev.r-universe.dev/\", getOption(\"repos\")))\n",
"Note: 'cmdstanr' is not available on CRAN.\n\n",
"Step 2: After installing 'cmdstanr', you can install CmdStan by running the following commands:\n",
"cmdstanr::check_cmdstan_toolchain(fix = TRUE)\n",
"cmdstanr::install_cmdstan()\n",
"This will install the latest version of CmdStan. For more information, visit:\n",
"https://mc-stan.org/users/interfaces/cmdstan"
)
}
# Check if CmdStan is installed
if (!instantiate::stan_cmdstan_exists()) {
clear_stan_model_cache()
stop(
"cmdstan is required to proceed.\n\n",
"You can install CmdStan by running the following command:\n",
"cmdstanr::check_cmdstan_toolchain(fix = TRUE)\n",
"cmdstanr::install_cmdstan()\n",
"This will install the latest version of CmdStan. For more information, visit:\n",
"https://mc-stan.org/users/interfaces/cmdstan"
)
}
}
drop_environment <- function(obj) {
# Check if the object has an environment
if (!is.null(environment(obj))) {
environment(obj) <- new.env()
}
return(obj)
}
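# Example (commented): replacing the enclosing environment discards captured
# objects, which can shrink serialised results considerably; note the function
# can no longer see variables from its original closure afterwards.
# f = local({ big = rnorm(1e6); function(x) x + 1 })
# length(serialize(f, NULL))                    # large: `big` is serialised too
# length(serialize(drop_environment(f), NULL))  # small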