# R/utilities.R


# Define global variable
sccomp_stan_models_cache_dir = file.path(path.expand("~"), ".sccomp_models", packageVersion("sccomp"))

# Greater than
gt = function(a, b){	a > b }

# Smaller than
st = function(a, b){	a < b }

# Negation
not = function(is){	!is }

#' Add attribute to an object
#'
#' @keywords internal
#' @noRd
#'
#'
#' @param var A tibble
#' @param attribute An object
#' @param name A character name of the attribute
#'
#' @return A tibble with an additional attribute
add_attr = function(var, attribute, name) {
  attr(var, name) <- attribute
  var
}


#' Formula parser
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom magrittr extract2
#' @importFrom stats terms
#'
#' @return A character vector
#'
#' @keywords internal
#' @noRd
parse_formula <- function(fm) {
  stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
  
  as.character(attr(terms(fm), "variables")) |>
    str_subset("\\|", negate = TRUE) %>%
    
    # Does not work the following
    # |>
    # extract2(-1)
    .[-1]
}
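
# Illustrative example (not run): for a formula mixing fixed and random
# effects, only the fixed-effect terms are returned; the "(1 | group)"
# term is dropped by str_subset().
# parse_formula(~ type + (1 | group))
# #> [1] "type"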


#' Formula parser
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom stringr str_split
#' @importFrom stringr str_remove_all
#' @importFrom rlang set_names
#' @importFrom purrr map_dfr
#' @importFrom stringr str_trim
#'
#' @importFrom magrittr extract2
#'
#' @return A character vector
#'
#' @keywords internal
#' @noRd
formula_to_random_effect_formulae <- function(fm) {
  
  # Define the variables as NULL to avoid CRAN NOTES
  formula <- NULL
  
  stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
  
  random_effect_elements =
    as.character(attr(terms(fm), "variables")) |>
    
    # Select random intercept part
    str_subset("\\|")
  
  if(length(random_effect_elements) > 0){
    
    random_effect_elements |>
      
      # Divide grouping from factors
      str_split("\\|") |>
      
      # Set name
      map_dfr(~ .x |> set_names(c("formula", "grouping"))) |>
      
      # Create formula
      mutate(formula = map(formula, ~ formula(glue("~ {.x}")))) |>
      mutate(grouping = grouping |> str_trim())
    
  } else {
    tibble(`formula` = list(), grouping = character())
  }
  
}
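
# Illustrative example (not run): a single random-intercept term yields a
# one-row tibble, with the formula stored as a list column.
# formula_to_random_effect_formulae(~ type + (1 | group))
# #> # A tibble: 1 x 2
# #>   formula   grouping
# #> 1 <formula> group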

#' Formula parser
#'
#' @param fm A formula
#'
#' @importFrom stringr str_subset
#' @importFrom stringr str_split
#' @importFrom stringr str_remove_all
#' @importFrom rlang set_names
#' @importFrom purrr map_dfr
#'
#' @importFrom magrittr extract2
#'
#' @return A character vector
#'
#' @keywords internal
#' @noRd
parse_formula_random_effect <- function(fm) {
  
  # Define the variables as NULL to avoid CRAN NOTES
  formula <- NULL
  
  stopifnot("The formula must be of the kind \"~ factors\" " = attr(terms(fm), "response") == 0)
  
  random_effect_elements =
    as.character(attr(terms(fm), "variables")) |>
    
    # Select random intercept part
    str_subset("\\|")
  
  if(length(random_effect_elements) > 0){
    
    formula_to_random_effect_formulae(fm) |>
      
      # Divide factors
      mutate(factor = map(
        formula,
        ~
          # Attach intercept
          .x |>
          terms() |>
          attr("intercept") |>
          str_replace("^1$", "(Intercept)") |>
          str_subset("0", negate = TRUE) |>
          
          # Attach variables
          c(
            .x |>
              terms() |>
              attr("variables") |>
              as.character() |>
              str_split("\\+") |>
              as.character() %>%
              .[-1]
          )
      )) |>
      unnest(factor)
    
  } else {
    tibble(factor = character(), grouping = character())
  }
  
}
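
# Illustrative example (not run): the intercept of the random-effect formula
# is recoded as "(Intercept)" in the factor column.
# parse_formula_random_effect(~ type + (1 | group))
# #> # A tibble: 1 x 3
# #>   formula   grouping factor
# #> 1 <formula> group    (Intercept)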

#' Get matrix from tibble
#'
#' @import dplyr
#'
#' @keywords internal
#' @noRd
#'
#' @importFrom purrr as_mapper
#'
#' @param .x A tibble
#' @param .p A logical scalar; if TRUE, .f1 is applied, otherwise .f2
#' @param .f1 A function or mapper formula applied when .p is TRUE
#' @param .f2 A function or mapper formula applied when .p is FALSE (optional; if NULL, .x is returned unchanged)
#'
#' @return A tibble
ifelse_pipe = function(.x, .p, .f1, .f2 = NULL) {
  switch(.p %>% `!` %>% sum(1),
         as_mapper(.f1)(.x),
         if (.f2 %>% is.null %>% `!`)
           as_mapper(.f2)(.x)
         else
           .x)
  
}
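
# Illustrative example (not run): apply .f1 only when the condition is TRUE,
# otherwise fall back to .f2 or leave the input unchanged.
# 4 |> ifelse_pipe(TRUE,  ~ sqrt(.x))  # returns 2
# 4 |> ifelse_pipe(FALSE, ~ sqrt(.x))  # returns 4, since .f2 is NULL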

#' @importFrom tidyr gather
#' @importFrom magrittr set_rownames
#'
#' @keywords internal
#' @noRd
#'
#' @param tbl A tibble
#' @param rownames A character string of the rownames
#'
#' @return A matrix
as_matrix <- function(tbl, rownames = NULL) {
  
  # Define the variables as NULL to avoid CRAN NOTES
  variable <- NULL
  
  tbl %>%
    
    ifelse_pipe(
      tbl %>%
        ifelse_pipe(!is.null(rownames),		~ .x %>% dplyr::select(-contains(rownames))) %>%
        summarise_all(class) %>%
        gather(variable, class) %>%
        pull(class) %>%
        unique() %>%
        `%in%`(c("numeric", "integer")) %>% not() %>% any(),
      ~ {
        warning("to_matrix says: there are NON-numerical columns, the matrix will NOT be numerical")
        .x
      }
    ) %>%
    as.data.frame() %>%
    
    # Deal with rownames column if present
    ifelse_pipe(!is.null(rownames),
                ~ .x  %>%
                  set_rownames(tbl %>% pull(!!rownames)) %>%
                  select(-!!rownames)) %>%
    
    # Convert to matrix
    as.matrix()
}
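
# Illustrative example (not run): a tibble with a rownames column becomes a
# numeric matrix.
# tibble::tibble(sample = c("a", "b"), x = 1:2, y = 3:4) |>
#   as_matrix(rownames = "sample")
# #>   x y
# #> a 1 3
# #> b 2 4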

#' vb_iterative
#'
#' @description Iteratively runs variational Bayes until it succeeds
#'
#'
#' @keywords internal
#' @noRd
#'
#' @param model A Stan model
#' @param output_samples An integer. How many samples to draw from the posterior
#' @param iter An integer. The maximum number of iterations
#' @param tol_rel_obj A real
#' @param additional_parameters_to_save A character vector
#' @param data A data frame
#' @param seed An integer
#' @param ... Additional parameters passed to Stan's variational inference function
#'
#' @return A Stan fit object
#'
vb_iterative = function(model,
                        output_samples,
                        iter,
                        tol_rel_obj,
                        additional_parameters_to_save = c(),
                        data,
                        output_dir,
                        seed, 
                        init = "random",
                        inference_method,
                        cores = 1, 
                        verbose = TRUE,
                        psis_resample = FALSE,
                        ...) {
  res = NULL
  i = 0
  while (is.null(res) && i < 5) {
    res = tryCatch({
      
      if(inference_method=="pathfinder")
        my_res = model |> 
          sample_safe(
            pathfinder_fx,
            data = data,
            tol_rel_obj = tol_rel_obj,
            output_dir = output_dir,
            seed = seed+i,
            # init = init,
            num_paths=50, 
            num_threads = cores,
            single_path_draws = output_samples / 50 ,
            max_lbfgs_iters=100, 
            history_size = 100, 
            show_messages = verbose,
            psis_resample = psis_resample,
            ...
          )
      
      else if(inference_method=="variational")
        my_res = model |> 
          sample_safe(
            variational_fx,
            data = data,
            output_samples = output_samples,
            iter = iter,
            tol_rel_obj = tol_rel_obj,
            output_dir = output_dir,
            seed = seed+i,
            init = init,
            show_messages = verbose,
            threads = cores,
            ...
          )
      
      return(my_res)
    },
    error = function(e) {
      
      writeLines(sprintf("Further attempt with Variational Bayes: %s", e))     
      
      return(NULL)
    },
    finally = {
    })
    i = i + 1
  }
  
  if(is.null(res)) stop(sprintf("sccomp says: variational Bayes did not converge after %s attempts. Please use variational_inference = FALSE for HMC fitting.", i))
  
  return(res)
}
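
# Illustrative sketch (not run; `mod` is a compiled cmdstanr model and
# `stan_data` a named list matching its data block):
# fit = vb_iterative(
#   mod, output_samples = 4000, iter = 10000, tol_rel_obj = 0.01,
#   data = stan_data, output_dir = tempdir(), seed = 42,
#   inference_method = "variational"
# )
# Each failed attempt increments the seed, so up to 5 distinct
# initialisations are tried before giving up.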




#' draws_to_tibble_x_y
#'
#' @importFrom tidyr pivot_longer
#' @importFrom rlang :=
#'
#' @param fit A fit object
#' @param par A character vector. The parameters to extract.
#' @param x A character. The first index.
#' @param y A character. The second index.
#'
#' @keywords internal
#' @noRd
draws_to_tibble_x_y = function(fit, par, x, y, number_of_draws = NULL) {
  
  # Define the variables as NULL to avoid CRAN NOTES
  dummy <- NULL
  .variable <- NULL
  .chain <- NULL
  .iteration <- NULL
  .draw <- NULL
  .value <- NULL

  
  fit$draws(variables = par, format = "draws_df") %>%
    mutate(.iteration = seq_len(n())) %>%
    
    pivot_longer(
      names_to = "parameter", # c( ".chain", ".variable", x, y),
      cols = contains(par),
      #names_sep = "\\.?|\\[|,|\\]|:",
      # names_ptypes = list(
      #   ".variable" = character()),
      values_to = ".value"
    ) %>%
    tidyr::extract(parameter, c(".chain", ".variable", x, y), "([1-9]+)?\\.?([a-zA-Z0-9_\\.]+)\\[([0-9]+),([0-9]+)") |> 
    
    # Warning message:
    # Expected 5 pieces. Additional pieces discarded
    suppressWarnings() %>%
    
    mutate(
      !!as.symbol(x) := as.integer(!!as.symbol(x)),
      !!as.symbol(y) := as.integer(!!as.symbol(y))
    ) %>%
    arrange(.variable, !!as.symbol(x), !!as.symbol(y), .chain) %>%
    group_by(.variable, !!as.symbol(x), !!as.symbol(y)) %>%
    mutate(.draw = seq_len(n())) %>%
    ungroup() %>%
    select(!!as.symbol(x), !!as.symbol(y), .chain, .iteration, .draw ,.variable ,     .value) %>%
    filter(.variable == par)
  
}
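
# Illustrative example (not run): a draws column named "beta[2,3]" is split by
# the regex above into .variable = "beta", C = 2, M = 3, so
# draws_to_tibble_x_y(fit, "beta", "C", "M")
# returns one row per draw for every covariate / cell-group combination.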


#' @importFrom tidyr separate
#' 
#' @param fit A fit object from a statistical model, e.g. from the 'cmdstanr' package.
#' @param par A character vector specifying the parameters to extract from the fit object.
#' @param x A character string specifying the first index in the parameter names.
#' @param y A character string specifying the second index in the parameter names (optional).
#' @param probs A numerical vector specifying the quantiles to extract.
#' 
#' @keywords internal
#' @noRd
summary_to_tibble = function(fit, par, x, y = NULL, probs = c(0.025, 0.25, 0.50, 0.75, 0.975)) {
  
  # Avoid bug
  #if(fit@stan_args[[1]]$method %>% is.null) fit@stan_args[[1]]$method = "hmc"
  
  summary = 
    fit$summary(variables = par, "mean", ~quantile(.x, probs = probs,  na.rm=TRUE), "rhat", "ess_bulk", "ess_tail") %>%
    rename(.variable = variable ) %>%
    
    # Split names such as "beta[2,3]" into .variable and its x, y indexes
    # (the two branches of the original purrr::when() were identical, so a
    # single call is equivalent)
    tidyr::separate(col = .variable,  into = c(".variable", x, y), sep="\\[|,|\\]", convert = TRUE, extra="drop")
  
  # # summaries are returned only for HMC
  # if(!"n_eff" %in% colnames(summary)) summary = summary |> mutate(n_eff = NA)
  # if(!"R_k_hat" %in% colnames(summary)) summary = summary |> mutate(R_k_hat = NA)
  
  summary
  
}
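
# Illustrative example (not run): for a 2-dimensional parameter the summary is
# indexed by both dimensions.
# summary_to_tibble(fit, "beta", "C", "M")
# #> # A tibble with columns .variable, C, M, mean, 2.5%, ..., rhat, ess_bulk, ess_tail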

#' @importFrom rlang :=
label_deleterious_outliers = function(.my_data){
  
  # Define the variables as NULL to avoid CRAN NOTES
  .count <- NULL
  `95%` <- NULL
  `5%` <- NULL
  X <- NULL
  iteration <- NULL
  outlier_above <- NULL
  slope <- NULL
  is_group_right <- NULL
  outlier_below <- NULL
  
  .my_data %>%
    
    # join CI
    mutate(outlier_above = !!.count > `95%`) %>%
    mutate(outlier_below = !!.count < `5%`) %>%
    
    # Mark if on the right of the factor scale
    mutate(is_group_right = !!as.symbol(colnames(X)[2]) > mean( !!as.symbol(colnames(X)[2]) )) %>%
    
    # Check if outlier might be deleterious for the statistics
    mutate(
      !!as.symbol(sprintf("deleterious_outlier_%s", iteration)) :=
        (outlier_above & slope > 0 & is_group_right)  |
        (outlier_below & slope > 0 & !is_group_right) |
        (outlier_above & slope < 0 & !is_group_right) |
        (outlier_below & slope < 0 & is_group_right)
    ) %>%
    
    select(-outlier_above, -outlier_below, -is_group_right)
  
}

#' @importFrom readr write_file
fit_model = function(
    data_for_model, model_name, censoring_iteration = 1, cores = detectCores(), quantile = 0.95,
    warmup_samples = 300, approximate_posterior_inference = NULL, inference_method, verbose = TRUE,
    seed , pars = c("beta", "alpha", "prec_coeff","prec_sd"), output_samples = NULL, chains=NULL, max_sampling_iterations = 20000, 
    output_directory = "sccomp_draws_files",
    ...
)
{
  
  
  # # if analysis approximated
  # # If posterior analysis is approximated I just need enough
  # how_many_posterior_draws_practical = ifelse(approximate_posterior_analysis, 1000, how_many_posterior_draws)
  # additional_parameters_to_save = additional_parameters_to_save %>% c("lambda_log_param", "sigma_raw") %>% unique
  
  
  # Find number of draws
  draws_supporting_quantile = 50
  if(is.null(output_samples)){
    
    output_samples =
      (draws_supporting_quantile/((1-quantile)/2)) %>% # /2 because I have two tails
      max(4000) 
    
    if(output_samples > max_sampling_iterations) {
      # message("sccomp says: the number of draws used to defined quantiles of the posterior distribution is capped to 20K.") # This means that for very low probability threshold the quantile could become unreliable. We suggest to limit the probability threshold between 0.1 and 0.01")
      output_samples = max_sampling_iterations
      
    }}
  
  # Find optimal number of chains
  if(is.null(chains))
    chains =
      find_optimal_number_of_chains(
        how_many_posterior_draws = output_samples,
        warmup = warmup_samples,
        parallelisation_start_penalty = 100
      ) %>%
      min(cores)
  
  # chains = 3
  
  init_list=list(
    prec_coeff = c(5,0),
    prec_sd = 1,
    alpha = matrix(c(rep(5, data_for_model$M), rep(0, (data_for_model$A-1) *data_for_model$M)), nrow = data_for_model$A, byrow = TRUE),
    beta_raw_raw = matrix(0, data_for_model$C , data_for_model$M-1) ,
    mix_p = 0.1 
  )
  
  if(data_for_model$n_random_eff>0){
    init_list$random_effect_raw = matrix(0, data_for_model$ncol_X_random_eff[1]  , data_for_model$M-1)  
    init_list$random_effect_sigma_raw = matrix(0, data_for_model$M-1 , data_for_model$how_many_factors_in_random_design[1])
    init_list$sigma_correlation_factor = array(0, dim = c(
      data_for_model$M-1, 
      data_for_model$how_many_factors_in_random_design[1], 
      data_for_model$how_many_factors_in_random_design[1]
    ))
    
    init_list$random_effect_sigma_mu = 0.5 |> as.array()
    init_list$random_effect_sigma_sigma = 0.2 |> as.array()
    init_list$zero_random_effect = rep(0, times = 1) |> as.array()
    
  } 
  
  if(data_for_model$n_random_eff>1){
    init_list$random_effect_raw_2 = matrix(0, data_for_model$ncol_X_random_eff[2]  , data_for_model$M-1)
    init_list$random_effect_sigma_raw_2 = matrix(0, data_for_model$M-1 , data_for_model$how_many_factors_in_random_design[2])
    init_list$sigma_correlation_factor_2 = array(0, dim = c(
      data_for_model$M-1, 
      data_for_model$how_many_factors_in_random_design[2], 
      data_for_model$how_many_factors_in_random_design[2]
    ))
    
  } 
  
  init = map(1:chains, ~ init_list) %>%
    setNames(as.character(1:chains))
  
  #output_directory = "sccomp_draws_files"
  dir.create(output_directory, showWarnings = FALSE)
  
  # Fit
  mod = load_model(model_name, threads = cores)
  
  # Avoid 0 proportions
  if(data_for_model$is_proportion && min(data_for_model$y_proportion)==0){
    warning("sccomp says: your proportion values include 0. Assuming that 0s derive from a precision threshold (e.g. deconvolution), 0s are converted to the smaller non 0 proportion value.")
    data_for_model$y_proportion[data_for_model$y_proportion==0] =
      min(data_for_model$y_proportion[data_for_model$y_proportion>0])
  }
  
  if(inference_method == "hmc"){
    
    tryCatch({
      mod |> sample_safe(
        sample_fx,
        data = data_for_model ,
        chains = chains,
        parallel_chains = chains,
        threads_per_chain = ceiling(cores/chains),
        iter_warmup = warmup_samples,
        iter_sampling = as.integer(output_samples /chains),
        #refresh = ifelse(verbose, 1000, 0),
        seed = seed,
        save_warmup = FALSE,
        init = init,
        output_dir = output_directory,
        show_messages = verbose,
        ...
      ) |> 
        suppressWarnings()
      
    },
    error = function(e) {
      
      # It is unclear why this is needed, and why the model sometimes is not compiled correctly
      if(e |> as.character() |>  str_detect("Model not compiled"))
        model = load_model(model_name, force=TRUE, threads = cores)
      else 
        stop(e)   
      
    })
    
    
  } else{
    if(inference_method=="pathfinder") init = pf
    else if(inference_method=="variational") init = list(init_list)
    
    vb_iterative(
      mod,
      output_samples = output_samples ,
      iter = 10000,
      tol_rel_obj = 0.01,
      data = data_for_model, refresh = ifelse(verbose, 1000, 0),
      seed = seed,
      output_dir = output_directory,
      init = init,
      inference_method = inference_method, 
      cores = cores,
      psis_resample = FALSE, 
      verbose = verbose
    ) %>%
      suppressWarnings()
    
  }
  
  
  
}


#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom rlang :=
#'
#' @keywords internal
#' @noRd
parse_fit = function(data_for_model, fit, censoring_iteration = 1, chains){
  
  # Define the variables as NULL to avoid CRAN NOTES
  M <- NULL
  
  
  fit %>%
    draws_to_tibble_x_y("beta", "C", "M") %>%
    left_join(tibble(C=seq_len(ncol(data_for_model$X)), C_name = colnames(data_for_model$X)), by = "C") %>%
    nest(!!as.symbol(sprintf("beta_posterior_%s", censoring_iteration)) := -M)
  
}

#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom tidyr spread
#' @importFrom stats C
#' @importFrom rlang :=
#' @importFrom tibble enframe
#'
#' @keywords internal
#' @noRd
beta_to_CI = function(fitted, censoring_iteration = 1, false_positive_rate, factor_of_interest){
  
  # Define the variables as NULL to avoid CRAN NOTES
  M <- NULL
  C_name <- NULL
  .lower <- NULL
  .median <- NULL
  .upper <- NULL
  
  effect_column_name = sprintf("composition_effect_%s", factor_of_interest) %>% as.symbol()
  
  CI = fitted %>%
    unnest(!!as.symbol(sprintf("beta_posterior_%s", censoring_iteration))) %>%
    nest(data = -c(M, C, C_name)) %>%
    # Attach beta
    mutate(!!as.symbol(sprintf("beta_quantiles_%s", censoring_iteration)) := map(
      data,
      ~ quantile(
        .x$.value,
        probs = c(false_positive_rate/2,  0.5,  1-(false_positive_rate/2))
      ) %>%
        enframe() %>%
        mutate(name = c(".lower", ".median", ".upper")) %>%
        spread(name, value)
    )) %>%
    unnest(!!as.symbol(sprintf("beta_quantiles_%s", censoring_iteration))) %>%
    select(-data, -C) %>%
    pivot_wider(names_from = C_name, values_from=c(.lower , .median ,  .upper)) 
  
  # Create main effect if exists
  if(!is.na(factor_of_interest) )
    CI |>
    mutate(!!effect_column_name := !!as.symbol(sprintf(".median_%s", factor_of_interest))) %>%
    nest(composition_CI = -c(M, !!effect_column_name))
  
  else 
    CI |> nest(composition_CI = -c(M))
  
}

#' @importFrom purrr map2_lgl
#' @importFrom tidyr pivot_wider
#' @importFrom stats C
#' @importFrom rlang :=
#'
#' @keywords internal
#' @noRd
alpha_to_CI = function(fitted, censoring_iteration = 1, false_positive_rate, factor_of_interest){
  
  # Define the variables as NULL to avoid CRAN NOTES
  M <- NULL
  C_name <- NULL
  .lower <- NULL
  .median <- NULL
  .upper <- NULL
  
  effect_column_name = sprintf("variability_effect_%s", factor_of_interest) %>% as.symbol()
  
  fitted %>%
    unnest(!!as.symbol(sprintf("alpha_%s", censoring_iteration))) %>%
    nest(data = -c(M, C, C_name)) %>%
    # Attach beta
    mutate(!!as.symbol(sprintf("alpha_quantiles_%s", censoring_iteration)) := map(
      data,
      ~ quantile(
        .x$.value,
        probs = c(false_positive_rate/2,  0.5,  1-(false_positive_rate/2))
      ) %>%
        enframe() %>%
        mutate(name = c(".lower", ".median", ".upper")) %>%
        spread(name, value)
    )) %>%
    unnest(!!as.symbol(sprintf("alpha_quantiles_%s", censoring_iteration))) %>%
    select(-data, -C) %>%
    pivot_wider(names_from = C_name, values_from=c(.lower , .median ,  .upper)) %>%
    mutate(!!effect_column_name := !!as.symbol(sprintf(".median_%s", factor_of_interest))) %>%
    nest(variability_CI = -c(M, !!effect_column_name))
  
  
  
}


#' Get Random Effect Design 2
#'
#' This function processes the formula composition elements in the data and creates design matrices
#' for random effect models.
#'
#' @param .data_ A data frame containing the data.
#' @param .sample A quosure representing the sample variable.
#' @param formula_composition The formula for the composition model.
#' 
#' @return A data frame with the processed design matrices for random intercept models.
#' 
#' @importFrom glue glue
#' @importFrom magrittr subtract
#' @importFrom purrr map2
#' @importFrom dplyr mutate
#' @importFrom dplyr select
#' @importFrom dplyr pull
#' @importFrom dplyr filter
#' @importFrom dplyr left_join
#' @importFrom dplyr mutate_all
#' @importFrom dplyr mutate_if
#' @importFrom dplyr as_tibble
#' @importFrom tidyr pivot_longer
#' @importFrom rlang enquo
#' @importFrom rlang quo_name
#' @importFrom tidyselect all_of
#' @importFrom readr type_convert
#' @noRd
get_random_effect_design2 = function(.data_, .sample, formula_composition ){
  
  # Define the variables as NULL to avoid CRAN NOTES
  formula <- NULL
  
  .sample = enquo(.sample)
  
  grouping_table =
    formula_composition |>
    formula_to_random_effect_formulae() |>
    
    mutate(design = map2(
      formula, grouping,
      ~ {
        
        mydesign = .data_ |> get_design_matrix(.x, !!.sample)
        
        mydesigncol_X_random_eff = .data_ |> select(all_of(.y)) |> pull(1) |> rep(ncol(mydesign)) |> matrix(ncol = ncol(mydesign))
        mydesigncol_X_random_eff[mydesign==0L] = NA
        colnames(mydesigncol_X_random_eff) = colnames(mydesign)
        rownames(mydesigncol_X_random_eff) = rownames(mydesign)
        
        mydesigncol_X_random_eff |>
          as_tibble(rownames = quo_name(.sample)) |>
          pivot_longer(-!!.sample, names_to = "factor", values_to = "grouping") |>
          filter(!is.na(grouping)) |>
          
          mutate("mean_idx" = glue("{factor}___{grouping}") |> as.factor() |> as.integer() )|>
          with_groups(factor, ~ ..1 |> mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx))) |>
          mutate(minus_sum = if_else(mean_idx==0, factor |> as.factor() |> as.integer(), 0L)) |>
          
          # Make right rank
          mutate(mean_idx = mean_idx |> as.factor() |> as.integer() |> subtract(1)) |>
          
          # drop minus_sum if we just have one grouping per factor
          with_groups(factor, ~ {
            if(length(unique(..1$grouping)) == 1) ..1 |> mutate(., minus_sum = 0)
            else ..1
          }) |>
          
          # Add value
          left_join(
            
            mydesign |>
              as.data.frame() |> 
              mutate_all(as.character) |> 
              readr::type_convert(guess_integer = TRUE ) |> 
              as_tibble(rownames = quo_name(.sample)) |>
              suppressMessages() |>
              mutate_if(is.integer, ~1) |>
              pivot_longer(-!!.sample, names_to = "factor"),
            
            by = join_by(!!.sample, factor)
          ) |>
          
          # Create unique name
          mutate(group___label = glue("{factor}___{grouping}")) |>
          mutate(group___numeric = group___label |> as.factor() |> as.integer()) |>
          mutate(factor___numeric = `factor` |> as.factor() |> as.integer())
        
        
        
      }))
  
}


#' Get Random Effect Design
#'
#' This function processes random effect elements in the data and creates design matrices
#' for random effect models.
#'
#' @param .data_ A data frame containing the data.
#' @param .sample A quosure representing the sample variable.
#' @param random_effect_elements A data frame containing the random intercept elements.
#' 
#' @return A data frame with the processed design matrices for random intercept models.
#' 
#' @importFrom glue glue
#' @importFrom magrittr subtract
#' @importFrom dplyr select
#' @importFrom dplyr mutate
#' @importFrom dplyr pull
#' @importFrom dplyr if_else
#' @importFrom dplyr distinct
#' @importFrom dplyr group_by
#' @importFrom dplyr summarise
#' @importFrom rlang set_names
#' @importFrom purrr map_lgl
#' @importFrom purrr pmap
#' @importFrom purrr map_int
#' @importFrom dplyr with_groups
#' @importFrom rlang enquo
#' @importFrom rlang quo_name
#' @importFrom tidyselect all_of
#' @noRd
get_random_effect_design = function(.data_, .sample, random_effect_elements ){
  
  # Define the variables as NULL to avoid CRAN NOTES
  is_factor_continuous <- NULL
  design <- NULL
  max_mean_idx <- NULL
  max_minus_sum <- NULL
  max_factor_numeric <- NULL
  max_group_numeric <- NULL
  min_mean_idx <- NULL
  min_minus_sum <- NULL
  
  .sample = enquo(.sample)
  
  # If intercept is not defined create it
  if(nrow(random_effect_elements) == 0 )
    return(
      random_effect_elements |>
        mutate(
          design = list(),
          is_factor_continuous = logical()
        )
    )
  
  # Otherwise process
  random_effect_elements |>
    mutate(is_factor_continuous = map_lgl(
      `factor`,
      ~ .x != "(Intercept)" && .data_ |> select(all_of(.x)) |> pull(1) |> is("numeric")
    )) |>
    mutate(design = pmap(
      list(grouping, `factor`, is_factor_continuous),
      ~ {
        
        # Make exception for random intercept
        if(..2 == "(Intercept)")
          .data_ = .data_ |> mutate(`(Intercept)` = 1)
        
        .data_ =
          .data_ |>
          select(!!.sample, ..1, ..2) |>
          set_names(c(quo_name(.sample), "group___", "factor___")) |>
          mutate(group___numeric = group___, factor___numeric = factor___) |>
          
          mutate(group___label := glue("{group___}___{.y}")) |>
          mutate(factor___ = ..2)
        
        
        # If factor is continuous
        if(..3)
          .data_ %>%
          
          # Mutate random intercept grouping to number
          mutate(group___numeric = factor(group___numeric) |> as.integer()) |>
          
          # If intercept is not defined create it
          mutate(., factor___numeric = 1L) |>
          
          # If categorical make sure the group is independent for factors
          mutate(mean_idx = glue("{group___numeric}") |> as.factor() |> as.integer()) |>
          mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx)) |>
          mutate(mean_idx = as.factor(mean_idx) |> as.integer() |> subtract(1L)) |>
          mutate(minus_sum = if_else(mean_idx==0, 1L, 0L))
        
        #|>
        #  distinct()
        
        # If factor is discrete
        else
          .data_ %>%
          
          # Mutate random intercept grouping to number
          mutate(group___numeric = factor(group___numeric) |> as.integer()) |>
          
          # If categorical make sure the group is independent for factors
          mutate(mean_idx = glue("{factor___numeric}{group___numeric}") |> as.factor() |> as.integer()) |>
          with_groups(factor___numeric, ~ ..1 |> mutate(mean_idx = if_else(mean_idx == max(mean_idx), 0L, mean_idx))) |>
          mutate(mean_idx = as.factor(mean_idx) |> as.integer() |> subtract(1L)) |>
          mutate(minus_sum = if_else(mean_idx==0, as.factor(factor___numeric) |> as.integer(), 0L)) |>
          
          # drop minus_sum if we just have one group___numeric per factor
          with_groups(factor___numeric, ~ {
            if(length(unique(..1$group___numeric)) == 1) ..1 |> mutate(., minus_sum = 0)
            else ..1
          }) |>
          mutate(factor___numeric = as.factor(factor___numeric) |> as.integer())
        
        #|>
        #  distinct()
      }
    )) |>
    
    # Make indexes unique across parameters
    mutate(
      max_mean_idx = map_int(design, ~ ..1 |> pull(mean_idx) |> max()),
      max_minus_sum = map_int(design, ~ ..1 |> pull(minus_sum) |> max()),
      max_factor_numeric = map_int(design, ~ ..1 |> pull(factor___numeric) |> max()),
      max_group_numeric = map_int(design, ~ ..1 |> pull(group___numeric) |> max())
    ) |>
    mutate(
      min_mean_idx = cumsum(max_mean_idx) - max_mean_idx ,
      min_minus_sum = cumsum(max_minus_sum) - max_minus_sum,
      max_factor_numeric = cumsum(max_factor_numeric) - max_factor_numeric,
      max_group_numeric = cumsum(max_group_numeric) - max_group_numeric
    ) |>
    mutate(design = pmap(
      list(design, min_mean_idx, min_minus_sum, max_factor_numeric, max_group_numeric),
      ~ ..1 |>
        mutate(
          mean_idx = if_else(mean_idx>0, mean_idx + ..2, mean_idx),
          minus_sum = if_else(minus_sum>0, minus_sum + ..3, minus_sum),
          factor___numeric = factor___numeric + ..4,
          group___numeric = group___numeric + ..5
          
        )
    ))
  
}

#' @importFrom glue glue
#' @noRd
get_design_matrix = function(.data_spread, formula, .sample){
  
  .sample = enquo(.sample)
  
  design_matrix =
    .data_spread %>%
    
    select(!!.sample, parse_formula(formula)) |>
    mutate(across(where(is.numeric),  scale)) |>
    model.matrix(formula, data=_)
  
  rownames(design_matrix) = .data_spread |> pull(!!.sample)
  
  design_matrix
}
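
# Illustrative example (not run): character columns are expanded to dummy
# columns and sample names become rownames.
# tibble::tibble(sample = c("a", "b"), type = c("ctrl", "case")) |>
#   get_design_matrix(~ type, sample)
# #>   (Intercept) typectrl
# #> a           1        1
# #> b           1        0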


#' Check Random Effect Design
#'
#' This function checks the validity of the random effect design in the data.
#'
#' @param .data A data frame containing the data.
#' @param factor_names A character vector of factor names.
#' @param random_effect_elements A data frame containing the random intercept elements.
#' @param formula The formula used for the model.
#' @param X The design matrix.
#' 
#' @return A data frame with the checked random intercept elements.
#' 
#' @importFrom tidyr nest
#' @importFrom dplyr mutate
#' @importFrom dplyr select
#' @importFrom dplyr pull
#' @importFrom dplyr filter
#' @importFrom dplyr distinct
#' @importFrom rlang set_names
#' @importFrom tidyr unite
#' @importFrom purrr map2
#' @importFrom stringr str_subset
#' @importFrom readr type_convert
#' @noRd
check_random_effect_design = function(.data, factor_names, random_effect_elements, formula, X){
  
  # Define the variables as NULL to avoid CRAN NOTES
  factors <- NULL
  groupings <- NULL
  
  
  .data_ = .data
  
  # Loop across groupings
  random_effect_elements |>
    nest(factors = `factor` ) |>
    mutate(checked = map2(
      grouping, factors,
      ~ {
        
        .y = unlist(.y)
        
        # Check that the group column is categorical
        stopifnot("sccomp says: the grouping column should be categorical (not numeric)" =
                    .data_ |>
                    select(all_of(.x)) |>
                    pull(1) |>
                    class() %in%
                    c("factor", "logical", "character")
        )
        
        
        # # Check sanity of the grouping if only random intercept
        # stopifnot(
        #   "sccomp says: the random intercept completely confounded with one or more discrete factors" =
        #     !(
        #       !.y |> equals("(Intercept)") &&
        #         .data_ |> select(any_of(.y)) |> suppressWarnings() |>  pull(1) |> class() %in% c("factor", "character") |> any() &&
        #         .data_ |>
        #         select(.x, any_of(.y)) |>
        #         select_if(\(x) is.character(x) | is.factor(x) | is.logical(x)) |>
        #         distinct() %>%
        #
        #         # TEMPORARY FIX
        #         set_names(c(colnames(.)[1], 'factor___temp')) |>
        #
        #         count(factor___temp) |>
        #         pull(n) |>
        #         equals(1) |>
        #         any()
        #     )
        # )
        
        # # Check if random intercept with random continuous slope. At the moment is not possible
        # # Because it would require I believe a multivariate prior
        # stopifnot(
        #   "sccomp says: continuous random slope is not supported yet" =
        #     !(
        #       .y |> str_subset("1", negate = TRUE) |> length() |> gt(0) &&
        #         .data_ |>
        #         select(
        #           .y |> str_subset("1", negate = TRUE)
        #         ) |>
        #         map_chr(class) %in%
        #         c("integer", "numeric")
        #     )
        # )
        
        # Check if random intercept with random continuous slope. At the moment is not possible
        # Because it would require I believe a multivariate prior
        stopifnot(
          "sccomp says: currently, discrete random slope is only supported in a intercept-free model. For example ~ 0 + treatment + (treatment | group)" =
            !(
              # If I have both random intercept and random discrete slope
              
              .y |> equals("(Intercept)") |> any() &&
                length(.y) > 1 &&
                # If I have random slope and non-intercept-free model
                .data_ |> select(any_of(.y)) |> suppressWarnings() |>  pull(1) |> class() %in% c("factor", "character") |> any()
              
            )
        )
        
        
        # I HAVE TO REVISIT THIS
        #  stopifnot(
        #   "sccomp says: the groups in the formula (factor | group) should not be shared across factor groups" =
        #     !(
        #       # If I duplicated groups
        #       .y  |> identical("(Intercept)") |> not() &&
        #       .data_ |> select(.y |> setdiff("(Intercept)")) |> lapply(class) != "numeric" &&
        #         .data_ |>
        #         select(.x, .y |> setdiff("(Intercept)")) |>
        #
        #         # Drop the factor represented by the intercept if any
        #         mutate(`parameter` = .y |> setdiff("(Intercept)")) |>
        #         unite("factor_name", c(parameter, factor), sep = "", remove = FALSE) |>
        #         filter(factor_name %in% colnames(X)) |>
        #
        #         # Count
        #         distinct() %>%
        #         set_names(as.character(1:ncol(.))) |>
        #         count(`1`) |>
        #         filter(n>1) |>
        #         nrow() |>
        #         gt(1)
        #
        #     )
        # )
        
      }
    ))
  
  random_effect_elements |>
    nest(groupings = grouping ) |>
    mutate(checked = map2(`factor`, groupings, ~{
      # Check the same group spans multiple factors
      stopifnot(
        "sccomp says: the groups in the formula (factor | group) should be present in only one factor, including the intercept" =
          !(
            # If I duplicated groups
            .y |> unlist() |> length() |> gt(1)
            
          )
      )
      
      
    }))
  
  
  
  
}

#' @importFrom purrr when
#' @importFrom stats model.matrix
#' @importFrom tidyr expand_grid
#' @importFrom stringr str_detect
#' @importFrom stringr str_remove_all
#' @importFrom purrr reduce
#' @importFrom stats as.formula
#'
#' @keywords internal
#' @noRd
#'
data_spread_to_model_input =
  function(
    .data_spread, formula, .sample, .cell_type, .count,
    truncation_ajustment = 1, approximate_posterior_inference ,
    formula_variability = ~ 1,
    contrasts = NULL,
    bimodal_mean_variability_association = FALSE,
    use_data = TRUE,
    random_effect_elements){
    
    # Define the variables as NULL to avoid CRAN NOTES
    exposure <- NULL
    design <- NULL
    mat <- NULL
    factor___numeric <- NULL
    mean_idx <- NULL
    design_matrix <- NULL
    minus_sum <- NULL
    group___numeric <- NULL
    idx <- NULL
    group___label <- NULL
    parameter <- NULL
    group <- NULL
    design_matrix_col <- NULL
    
    # Prepare column same enquo
    .sample = enquo(.sample)
    .cell_type = enquo(.cell_type)
    .count = enquo(.count)
    .grouping_for_random_effect =
      random_effect_elements |>
      pull(grouping) |>
      unique() 
    
    if (length(.grouping_for_random_effect)==0 ) .grouping_for_random_effect = "random_effect"
    
    
    X  =
      
      .data_spread |>
      get_design_matrix(
        # Drop random intercept
        formula |>
          as.character() |>
          str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
          paste(collapse="") |>
          as.formula(),
        !!.sample
      )
    
    Xa  =
      .data_spread |>
      get_design_matrix(
        # Drop random intercept
        formula_variability |>
          as.character() |>
          str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
          paste(collapse="") |>
          as.formula() ,
        !!.sample
      )
    
    XA = Xa %>%
      as_tibble() %>%
      distinct()
    
    A = ncol(XA);
    Ar = nrow(XA);
    
    factor_names = parse_formula(formula)
    factor_names_variability = parse_formula(formula_variability)
    cell_cluster_names = .data_spread %>% select(-!!.sample, -any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% colnames()
    
    # Random intercept
    if(nrow(random_effect_elements)>0 ) {
      
      #check_random_effect_design(.data_spread, any_of(factor_names), random_effect_elements, formula, X)
      random_effect_grouping = get_random_effect_design2(.data_spread, !!.sample,  formula )
      
      # Actual parameters, excluding for the sum to one parameters
      is_random_effect = 1
      
      random_effect_grouping =
        random_effect_grouping |>
        mutate(design_matrix = map(
          design,
          ~ ..1 |>
            select(!!.sample, group___label, value) |>
            pivot_wider(names_from = group___label, values_from = value) |>
            mutate(across(everything(), ~ .x |> replace_na(0)))
        )) 
      
      
      X_random_effect = 
        random_effect_grouping |> 
        pull(design_matrix) |> 
        _[[1]]  |>  
        as_matrix(rownames = quo_name(.sample))
      
      # For now that Stan does not have tuples, allow at most two grouping levels
      if(random_effect_grouping |> nrow() > 2) stop("sccomp says: at the moment sccomp allows at most two groupings")
      # This will be modularised with the new stan
      if(random_effect_grouping |> nrow() > 1)
        X_random_effect_2 =   
        random_effect_grouping |> 
        pull(design_matrix) |> 
        _[[2]] |>  
        as_matrix(rownames = quo_name(.sample))
      else X_random_effect_2 =  X_random_effect[,0,drop=FALSE]
      
      n_random_eff = random_effect_grouping |> nrow()
      
      ncol_X_random_eff =
        random_effect_grouping |>
        mutate(n = map_int(design, ~.x |> distinct(group___numeric) |> nrow())) |>
        pull(n) 
      
      if(ncol_X_random_eff |> length() < 2) ncol_X_random_eff[2] = 0
      
      # TEMPORARY
      group_factor_indexes_for_covariance = 
        X_random_effect |> 
        colnames() |> 
        enframe(value = "parameter", name = "order")  |> 
        separate(parameter, c("factor", "group"), "___", remove = FALSE) |> 
        complete(factor, group, fill = list(order=0)) |> 
        select(-parameter) |> 
        pivot_wider(names_from = group, values_from = order)  |> 
        as_matrix(rownames = "factor")
      
      n_groups = group_factor_indexes_for_covariance |> ncol()
      
      # This will be modularised with the new stan
      if(random_effect_grouping |> nrow() > 1)
        group_factor_indexes_for_covariance_2 = 
        X_random_effect_2 |> 
        colnames() |> 
        enframe(value = "parameter", name = "order")  |> 
        separate(parameter, c("factor", "group"), "___", remove = FALSE) |> 
        complete(factor, group, fill = list(order=0)) |> 
        select(-parameter) |> 
        pivot_wider(names_from = group, values_from = order)  |> 
        as_matrix(rownames = "factor")
      else group_factor_indexes_for_covariance_2 = matrix()[0,0, drop=FALSE]
      
      n_groups = n_groups |> c(group_factor_indexes_for_covariance_2 |> ncol())
      
      how_many_factors_in_random_design = list(group_factor_indexes_for_covariance, group_factor_indexes_for_covariance_2) |> map_int(nrow)
      
      
    } else {
      X_random_effect = matrix(rep(1, nrow(.data_spread)))[,0, drop=FALSE]
      X_random_effect_2 = matrix(rep(1, nrow(.data_spread)))[,0, drop=FALSE] # This will be modularised with the new stan
      is_random_effect = 0
      ncol_X_random_eff = c(0,0)
      n_random_eff = 0
      n_groups = c(0,0)
      how_many_factors_in_random_design = c(0,0)
      group_factor_indexes_for_covariance = matrix()[0,0, drop=FALSE]
      group_factor_indexes_for_covariance_2 = matrix()[0,0, drop=FALSE] # This will be modularised with the new stan
    }
    
    
    y = .data_spread %>% select(-any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% as_matrix(rownames = quo_name(.sample))
    
    # Detect whether the input is proportions (all values within [0, 1])
    is_proportion = y |> as.numeric() |> max()  |> between(0,1) |> all()
    if(is_proportion){
      y_proportion = y
      y = y[0,,drop = FALSE]
    }
    else{
      y = y
      y_proportion = y[0,,drop = FALSE]
    }
    
    data_for_model =
      list(
        N = .data_spread %>% nrow(),
        M = .data_spread %>% select(-!!.sample, -any_of(factor_names), -exposure, -!!.grouping_for_random_effect) %>% ncol(),
        exposure = .data_spread$exposure,
        is_proportion = is_proportion,
        y = y,
        y_proportion = y_proportion,
        X = X,
        XA = XA,
        Xa = Xa,
        C = ncol(X),
        A = A,
        Ar = Ar,
        truncation_ajustment = truncation_ajustment,
        is_vb = as.integer(approximate_posterior_inference),
        bimodal_mean_variability_association = bimodal_mean_variability_association,
        use_data = use_data,
        
        # Random intercept
        is_random_effect = is_random_effect,
        ncol_X_random_eff = ncol_X_random_eff,
        n_random_eff = n_random_eff,
        n_groups  = n_groups,
        X_random_effect = X_random_effect,
        X_random_effect_2 = X_random_effect_2,
        group_factor_indexes_for_covariance = group_factor_indexes_for_covariance,
        group_factor_indexes_for_covariance_2 = group_factor_indexes_for_covariance_2,
        how_many_factors_in_random_design = how_many_factors_in_random_design,
        
        # For parallel chains
        grainsize = 1,
        
        ## LOO
        enable_loo = FALSE
      )
    
    # Add censoring
    data_for_model$is_truncated = 0
    data_for_model$truncation_up = matrix(rep(-1, data_for_model$M * data_for_model$N), ncol = data_for_model$M)
    data_for_model$truncation_down = matrix(rep(-1, data_for_model$M * data_for_model$N), ncol = data_for_model$M)
    data_for_model$truncation_not_idx = seq_len(data_for_model$M*data_for_model$N)
    data_for_model$TNS = length(data_for_model$truncation_not_idx)
    data_for_model$truncation_not_idx_minimal = matrix(c(1,1), nrow = 1)[0,,drop=FALSE]
    data_for_model$TNIM = 0
    
    # Add parameter factor dictionary
    data_for_model$factor_parameter_dictionary = tibble()
    
    if(.data_spread  |> select(any_of(parse_formula(formula))) |> lapply(class) %in% c("factor", "character") |> any())
      data_for_model$factor_parameter_dictionary =
      data_for_model$factor_parameter_dictionary |> bind_rows(
        # For discrete
        .data_spread  |>
          select(any_of(parse_formula(formula)))  |>
          distinct()  |>
          
          # Drop numerical
          select_if(function(x) !is.numeric(x)) |>
          pivot_longer(everything(), names_to =  "factor", values_to = "parameter") %>%
          unite("design_matrix_col", c(`factor`, parameter), sep="", remove = FALSE)  |>
          select(-parameter) |>
          filter(design_matrix_col %in% colnames(data_for_model$X)) %>%
          distinct()
        
      )
    
    # For continuous
    if(.data_spread  |> select(all_of(parse_formula(formula))) |> lapply(class) |> equals("numeric") |> any())
      data_for_model$factor_parameter_dictionary =
      data_for_model$factor_parameter_dictionary |>
      bind_rows(
        tibble(
          design_matrix_col =  .data_spread  |>
            select(all_of(parse_formula(formula)))  |>
            distinct()  |>
            
            # Drop numerical
            select_if(function(x) is.numeric(x)) |>
            names()
        ) |>
          mutate(`factor` = design_matrix_col)
      )
    
    # If contrasts is set it is a bit more complicated
    if(! is.null(contrasts))
      data_for_model$factor_parameter_dictionary =
      data_for_model$factor_parameter_dictionary |>
      distinct() |>
      expand_grid(parameter=contrasts) |>
      filter(str_detect(parameter, design_matrix_col )) |>
      select(-design_matrix_col) |>
      rename(design_matrix_col = parameter) |>
      distinct()
    
    data_for_model$intercept_in_design = X[,1] |> unique() |> identical(1)
    
    
    if (data_for_model$intercept_in_design | length(factor_names_variability) == 0) {
      data_for_model$A_intercept_columns = 1
    } else {
      data_for_model$A_intercept_columns = 
        .data_spread |> 
        select(any_of(factor_names[1])) |> 
        distinct() |> 
        nrow()
    }
    
    
    if (data_for_model$intercept_in_design ) {
      data_for_model$B_intercept_columns = 1
    } else {
      data_for_model$B_intercept_columns = 
        .data_spread |> 
        select(any_of(factor_names[1])) |> 
        distinct() |> 
        nrow()
    }
    
    # Return
    data_for_model
  }

data_to_spread = function(.data, formula, .sample, .cell_type, .count, .grouping_for_random_effect){
  
  .sample = enquo(.sample)
  .cell_type = enquo(.cell_type)
  .count = enquo(.count)
  .grouping_for_random_effect = .grouping_for_random_effect |> map(~ .x |> quo_name() ) |> unlist()
  
  is_proportion = .data |> pull(!!.count) |> max() <= 1
  
  .data = 
    .data |>
    nest(data = -!!.sample) 
  
  # If proportions exposure = 1
  if(is_proportion) .data = .data |> mutate(exposure = 1)
  else
    .data = 
    .data |>
    mutate(exposure = map_int(data, ~ .x |> pull(!!.count) |> sum() )) 
  
  .data_to_spread = 
    .data |>
    unnest(data) |>
    select(!!.sample, !!.cell_type, exposure, !!.count, parse_formula(formula), any_of(.grouping_for_random_effect)) 
  
  # Check if duplicated samples
  if(
    .data_to_spread |> distinct(!!.sample, !!.cell_type) |> nrow() <
    .data_to_spread |> nrow()
  ) stop("sccomp says: You have duplicated .sample IDs in your input dataset. A .sample .cell_group combination must be unique")
  
  .data_to_spread |>
    spread(!!.cell_type, !!.count)
  
  
}

#' @importFrom purrr when
#' @importFrom stats model.matrix
#'
#' @keywords internal
#' @noRd
#'
data_simulation_to_model_input =
  function(.data, formula, .sample, .cell_type, .exposure, .coefficients, truncation_ajustment = 1, approximate_posterior_inference ){
    
    # Define the variables as NULL to avoid CRAN NOTES
    sd <- NULL
    . <- NULL
    
    
    # Prepare column same enquo
    .sample = enquo(.sample)
    .cell_type = enquo(.cell_type)
    .exposure = enquo(.exposure)
    .coefficients = enquo(.coefficients)
    
    factor_names = parse_formula(formula)
    
    sample_data =
      .data %>%
      select(!!.sample, any_of(factor_names)) %>%
      distinct() %>%
      arrange(!!.sample)
    X =
      sample_data %>%
      model.matrix(formula, data=.) %>%
      apply(2, function(x) {
        
        if(sd(x)==0 ) x
        else x |> scale(scale=FALSE)
        
      } ) %>%
      {
        .x = (.)
        rownames(.x) = sample_data %>% pull(!!.sample)
        .x
      }
    
    if(factor_names == "1") XA = X[,1, drop=FALSE]
    else XA = X[,c(1,2), drop=FALSE]
    
    XA = XA |> 
      as_tibble()  |> 
      distinct()
    
    cell_cluster_names =
      .data %>%
      distinct(!!.cell_type) %>%
      arrange(!!.cell_type) %>%
      pull(!!.cell_type)
    
    coefficients =
      .data %>%
      select(!!.cell_type, !!.coefficients) %>%
      unnest(!!.coefficients) %>%
      distinct() %>%
      arrange(!!.cell_type) %>%
      as_matrix(rownames = quo_name(.cell_type)) %>%
      t()
    
    list(
      N = .data %>% distinct(!!.sample) %>% nrow(),
      M = .data %>% distinct(!!.cell_type) %>% nrow(),
      exposure = .data %>% distinct(!!.sample, !!.exposure) %>% arrange(!!.sample) %>% pull(!!.exposure),
      X = X,
      XA = XA,
      C = ncol(X),
      A =  ncol(XA),
      beta = coefficients
    )
    
  }


#' Choose the number of chains based on how many draws we need from the posterior distribution.
#' Because there is a fixed cost (warmup) to starting a new chain,
#' we use the smallest number of chains that minimises the total cost while still parallelising
#' @param how_many_posterior_draws A real number of posterior draws needed
#' @param max_number_to_check A sane upper bound for the number of chains to consider
#' @param warmup An integer. The number of warmup iterations per chain
#' @param parallelisation_start_penalty A numeric. The fixed cost of starting one extra chain
#'
#' @keywords internal
#' @noRd
#'
#' @return An integer. The optimal number of chains
find_optimal_number_of_chains = function(how_many_posterior_draws = 100,
                                         max_number_to_check = 100, warmup = 200, parallelisation_start_penalty = 100) {
  
  
  
  # Define the variables as NULL to avoid CRAN NOTES
  chains <- NULL
  
  
  chains_df =
    tibble(chains = seq_len(max_number_to_check)) %>%
    mutate(tot = (how_many_posterior_draws / chains) + warmup + (parallelisation_start_penalty * chains))
  
  d1 <- diff(chains_df$tot) / diff(seq_len(nrow(chains_df))) # first derivative
  abs(d1) %>% order() %>% .[1] # Index where the derivative is closest to 0
  
  
}
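
# Illustrative check (not run): minimising
#   tot(chains) = draws / chains + warmup + penalty * chains
# analytically gives chains* = sqrt(draws / penalty). For 4000 draws and
# penalty = 100 that is sqrt(40) ~ 6.3, so
# find_optimal_number_of_chains(4000, warmup = 300, parallelisation_start_penalty = 100)
# returns approximately 6.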


get.elbow.points.indices <- function(x, y, threshold) {
  # From https://stackoverflow.com/questions/41518870/finding-the-elbow-knee-in-a-curve
  d1 <- diff(y) / diff(x) # first derivative
  d2 <- diff(d1) / diff(x[-1]) # second derivative
  indices <- which(abs(d2) > threshold)
  return(indices)
}
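
# Illustrative example (not run): on a piecewise-linear curve the second
# derivative spikes around the bend.
# x = 1:10
# y = c(10, 9, 8, 7, 6, 5.5, 5.4, 5.3, 5.2, 5.1)
# get.elbow.points.indices(x, y, threshold = 0.3)
# #> [1] 4 5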

#' @importFrom magrittr divide_by
#' @importFrom magrittr multiply_by
#' @importFrom stats C
#'
#' @keywords internal
#' @noRd
#'
get_probability_non_zero_OLD = function(.data, prefix = "", test_above_logit_fold_change = 0){
  
  # Define the variables as NULL to avoid CRAN NOTES
  .draw <- NULL
  M <- NULL
  C_name <- NULL
  bigger_zero <- NULL
  smaller_zero <- NULL
  
  probability_column_name = sprintf("%s_prob_H0", prefix) %>% as.symbol()
  
  total_draws = .data %>% pull(2) %>% .[[1]] %>% distinct(.draw) %>% nrow()
  
  .data %>%
    unnest(2 ) %>%
    filter(C ==2) %>%
    nest(data = -c(M, C_name)) %>%
    mutate(
      bigger_zero = map_int(data, ~ .x %>% filter(.value>test_above_logit_fold_change) %>% nrow),
      smaller_zero = map_int(data, ~ .x %>% filter(.value< -test_above_logit_fold_change) %>% nrow)
    ) %>%
    rowwise() %>%
    mutate(
      !!probability_column_name :=
        1 - (
          max(bigger_zero, smaller_zero) %>%
            #max(1) %>%
            divide_by(total_draws)
        )
    )  %>%
    ungroup() %>%
    select(M, !!probability_column_name)
  # %>%
  # mutate(false_discovery_rate = cummean(prob_non_zero))
  
}

#' @importFrom magrittr divide_by
#' @importFrom magrittr multiply_by
#' @importFrom stats C
#' @importFrom stats setNames
#'
#' @keywords internal
#' @noRd
#'
get_probability_non_zero_ = function(fit, parameter, prefix = "", test_above_logit_fold_change = 0){
  
  # Define the variables as NULL to avoid CRAN NOTES
  M <- NULL
  C_name <- NULL
  bigger_zero <- NULL
  smaller_zero <- NULL
  
  
  
  draws = fit$draws(
    variables = parameter,
    inc_warmup = FALSE,
    format = getOption("cmdstanr_draws_format", "draws_matrix")
  )
  
  total_draws = dim(draws)[1]
  
  bigger_zero =
    draws %>%
    apply(2, function(x) (x>test_above_logit_fold_change) %>% which %>% length)
  
  smaller_zero =
    draws %>%
    apply(2, function(x) (x< -test_above_logit_fold_change) %>% which %>% length)
  
  (1 - (pmax(bigger_zero, smaller_zero) / total_draws)) %>%
    enframe() %>%
    tidyr::extract(name, c("C", "M"), ".+\\[([0-9]+),([0-9]+)\\]") %>%
    mutate(across(c(C, M), ~ as.integer(.x))) %>%
    tidyr::spread(C, value)
  
}

get_probability_non_zero = function(draws, test_above_logit_fold_change = 0, probability_column_name){
  
  draws %>%
    with_groups(c(M, C_name), ~ .x |> summarise(
      bigger_zero = which(.value>test_above_logit_fold_change) |> length(),
      smaller_zero = which(.value< -test_above_logit_fold_change) |> length(),
      n=n()
    )) |>
    mutate(!!as.symbol(probability_column_name) :=  (1 - (pmax(bigger_zero, smaller_zero) / n)))
  
}
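
# Illustrative example (not run): a coefficient whose draws are 90% positive
# gets a null-hypothesis probability of 1 - 9/10 = 0.1.
# draws = dplyr::tibble(M = 1, C_name = "typeB", .value = c(rep(1, 9), -1))
# get_probability_non_zero(draws, 0, "prob_H0")
# #>     M C_name bigger_zero smaller_zero     n prob_H0
# #>     1 typeB            9            1    10     0.1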

#' @keywords internal
#' @noRd
#'
parse_generated_quantities = function(rng, number_of_draws = 1){
  
  # Define the variables as NULL to avoid CRAN NOTES
  .draw <- NULL
  N <- NULL
  .value <- NULL
  generated_counts <- NULL
  M <- NULL
  generated_proportions <- NULL
  
  draws_to_tibble_x_y(rng, "counts", "N", "M", number_of_draws) %>%
    with_groups(c(.draw, N), ~ .x %>% mutate(generated_proportions = .value/max(1, sum(.value)))) %>%
    filter(.draw<= number_of_draws) %>%
    rename(generated_counts = .value, replicate = .draw) %>%
    
    mutate(generated_counts = as.integer(generated_counts)) %>%
    select(M, N, generated_proportions, generated_counts, replicate)
  
}

#' design_matrix_and_coefficients_to_simulation
#'
#' @description Create simulation from design matrix and coefficient matrix
#'
#' @importFrom dplyr left_join
#' @importFrom tidyr expand_grid
#'
#' @keywords internal
#' @noRd
#'
#' @param design_matrix A matrix
#' @param coefficient_matrix A matrix
#'
#' @return A data frame
#'
#'
#'
design_matrix_and_coefficients_to_simulation = function(
    
  design_matrix, coefficient_matrix, .estimate_object
  
){
  
  # Define the variables as NULL to avoid CRAN NOTES
  cell_type <- NULL
  beta_1 <- NULL
  beta_2 <- NULL
  
  design_df = as.data.frame(design_matrix)
  coefficient_df = as.data.frame(coefficient_matrix)
  
  rownames(design_df) = sprintf("sample_%s", seq_len(nrow(design_df)))
  colnames(design_df) = sprintf("factor_%s", seq_len(ncol(design_df)))
  
  rownames(coefficient_df) = sprintf("cell_type_%s", seq_len(nrow(coefficient_df)))
  colnames(coefficient_df) = sprintf("beta_%s", seq_len(ncol(coefficient_df)))
  
  input_data =
    expand_grid(
      sample = rownames(design_df),
      cell_type = rownames(coefficient_df)
    ) |>
    left_join(design_df |> as_tibble(rownames = "sample") , by = "sample") |>
    left_join(coefficient_df |>as_tibble(rownames = "cell_type"), by = "cell_type")
  
  simulate_data(.data = input_data,
                
                .estimate_object = .estimate_object,
                
                formula_composition = ~ factor_1 ,
                .sample = sample,
                .cell_group = cell_type,
                .coefficients = c(beta_1, beta_2),
                mcmc_seed = sample(1e5, 1)
  )
  
  
}

#' @importFrom rlang ensym
#' @noRd
class_list_to_counts = function(.data, .sample, .cell_group){
  
  .sample_for_tidyr = ensym(.sample)
  .cell_group_for_tidyr = ensym(.cell_group)
  
  .sample = enquo(.sample)
  .cell_group = enquo(.cell_group)
  
  .data %>%
    count(!!.sample,
          !!.cell_group,
          name = "count") %>%
    
    complete(
      !!.sample_for_tidyr,!!.cell_group_for_tidyr,
      fill = list(count = 0)
    ) %>%
    mutate(count = as.integer(count))
}
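
# Illustrative example (not run): one row per cell in, a complete count table
# out, with missing sample/cell-group combinations filled with 0.
# tibble::tibble(sample = c("s1", "s1", "s2"), cell_type = c("T", "B", "T")) |>
#   class_list_to_counts(sample, cell_type)
# #> s1 B 1 / s1 T 1 / s2 B 0 / s2 T 1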

#' @importFrom dplyr cummean
#' @noRd
get_FDR = function(x){
  
  # Define the variables as NULL to avoid CRAN NOTES
  value <- NULL
  name <- NULL
  FDR <- NULL
  
  
  enframe(x) %>%
    arrange(value) %>%
    mutate(FDR = cummean(value)) %>%
    arrange(name) %>%
    pull(FDR)
}
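
# Illustrative example (not run): probabilities are sorted, the cumulative
# mean is taken, and results are returned in the original order.
# get_FDR(c(0.01, 0.20, 0.03))
# #> [1] 0.01 0.08 0.02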

#' Plot 1D Intervals for Cell-group Effects
#'
#' This function creates a series of 1D interval plots for cell-group effects, highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param test_composition_above_logit_fold_change A positive real number. It is the effect threshold used for the hypothesis test. A value of 0.2 corresponds to a change in cell proportion of about 10% for a cell type with a baseline proportion of 50%, i.e. the cell type going from 45% to 50%. When the baseline proportion is closer to 0 or 1, this effect threshold has a consistent value on the unconstrained logit scale.
#' @importFrom patchwork wrap_plots
#' @importFrom forcats fct_reorder
#' @importFrom tidyr drop_na
#' 
#' @export
#' 
#' @return A combined plot of 1D interval plots.
#' @examples
#' # Example usage:
#' # plot_1D_intervals(.data, significance_threshold = 0.025)
plot_1D_intervals = function(.data, significance_threshold = 0.05, test_composition_above_logit_fold_change = .data |> attr("test_composition_above_logit_fold_change")){
  
  # Define the variables as NULL to avoid CRAN NOTES
  parameter <- NULL
  estimate <- NULL
  value <- NULL
  
  .cell_group = attr(.data, ".cell_group")
  
  # Check if tests have been run
  if(.data |> select(ends_with("FDR")) |> ncol() |> equals(0))
    stop("sccomp says: to produce plots, you need to run the function sccomp_test() on your estimates.")
  
  
  plot_list = 
    .data |>
    filter(parameter != "(Intercept)") |>
    
    # Reshape data
    select(-contains("n_eff"), -contains("R_k_hat"), -contains("_ess"), -contains("_rhat")) |> 
    pivot_longer(c(contains("c_"), contains("v_")), names_sep = "_", names_to = c("which", "estimate")) |>
    pivot_wider(names_from = estimate, values_from = value) |>
    
    # Nest data by parameter and which
    nest(data = -c(parameter, which)) |>
    mutate(plot = pmap(
      list(data, which, parameter),
      ~  {
        
        # Check if there are any statistics to plot
        if(..1 |> filter(!effect |> is.na()) |> nrow() |> equals(0))
          return(NA)
        
        # Create ggplot for each nested data
        ggplot(..1, aes(x = effect, y = fct_reorder(!!.cell_group, effect))) +
          geom_vline(xintercept = test_composition_above_logit_fold_change, colour = "grey") +
          geom_vline(xintercept = -test_composition_above_logit_fold_change, colour = "grey") +
          geom_errorbar(aes(xmin = lower, xmax = upper, color = FDR < significance_threshold)) +
          geom_point() +
          scale_color_brewer(palette = "Set1") +
          xlab("Credible interval of the slope") +
          ylab("Cell group") +
          ggtitle(sprintf("%s %s", ..2, ..3)) +
          multipanel_theme +
          theme(legend.position = "bottom") 
      }
    )) %>%
    
    # Filter out NA plots
    filter(!plot |> is.na()) |> 
    pull(plot) 
  
  # Combine all individual plots into one plot
  plot_list |>
    wrap_plots(ncol = plot_list |> length() |> sqrt() |> ceiling())
}


#' Plot 2D Intervals for Mean-Variance Association
#'
#' This function creates a 2D interval plot for mean-variance association, highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param test_composition_above_logit_fold_change A positive number. It is the effect threshold used for the hypothesis test. A value of 0.2 corresponds to a change in cell proportion of about 10% for a cell type with a baseline proportion of 50% (e.g. going from 45% to 50%). When the baseline proportion is closer to 0 or 1, this effect threshold has a consistent value on the unconstrained logit scale.
#' 
#' 
#' @importFrom dplyr filter arrange mutate if_else row_number
#' @importFrom ggplot2 ggplot geom_vline geom_hline geom_errorbar geom_point annotate aes facet_wrap
#' @importFrom ggrepel geom_text_repel
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#' @importFrom magrittr equals
#' 
#' @export
#' 
#' @return A ggplot object representing the 2D interval plot.
#' @examples
#' # Example usage:
#' # plot_2D_intervals(.data, significance_threshold = 0.05)
plot_2D_intervals = function(
    .data, 
    significance_threshold = 0.05, 
    test_composition_above_logit_fold_change = 
      .data |> attr("test_composition_above_logit_fold_change")
){
  
  # Define the variables as NULL to avoid CRAN NOTES
  v_effect <- NULL
  parameter <- NULL
  c_effect <- NULL
  c_lower <- NULL
  c_upper <- NULL
  c_FDR <- NULL
  v_lower <- NULL
  v_upper <- NULL
  v_FDR <- NULL
  cell_type_label <- NULL
  
  
  .cell_group = attr(.data, ".cell_group")
  
  # Check if tests have been run
  if(.data |> select(ends_with("FDR")) |> ncol() |> equals(0))
    stop("sccomp says: to produce plots, you need to run the function sccomp_test() on your estimates.")
  
  
  # Mean-variance association
  .data %>%
    
    # Filter where variance is inferred
    filter(!is.na(v_effect)) %>%
    
    # Add labels for significant cell groups
    with_groups(
      parameter,
      ~ .x %>%
        arrange(c_FDR) %>%
        mutate(cell_type_label = if_else(row_number() <= 3 & c_FDR < significance_threshold & parameter != "(Intercept)", !!.cell_group, ""))
    ) %>%
    with_groups(
      parameter,
      ~ .x %>%
        arrange(v_FDR) %>%
        mutate(cell_type_label = if_else((row_number() <= 3 & v_FDR < significance_threshold & parameter != "(Intercept)"), !!.cell_group, cell_type_label))
    ) %>%
    
    {
      .x = (.)
      
      # Plot
      ggplot(.x, aes(c_effect, v_effect)) +
        
        # Add vertical and horizontal lines
        geom_vline(xintercept = c(-test_composition_above_logit_fold_change, test_composition_above_logit_fold_change), colour = "grey", linetype = "dashed", linewidth = 0.3) +
        geom_hline(yintercept = c(-test_composition_above_logit_fold_change, test_composition_above_logit_fold_change), colour = "grey", linetype = "dashed", linewidth = 0.3) +
        
        # Add error bars
        geom_errorbar(aes(xmin = `c_lower`, xmax = `c_upper`, color = `c_FDR` < significance_threshold, alpha = `c_FDR` < significance_threshold), linewidth = 0.2) +
        geom_errorbar(aes(ymin = v_lower, ymax = v_upper, color = `v_FDR` < significance_threshold, alpha = `v_FDR` < significance_threshold), linewidth = 0.2) +
        
        # Add points
        geom_point(size = 0.2) +
        
        # Add annotations
        annotate("text", x = 0, y = 3.5, label = "Variability", size = 2) +
        annotate("text", x = 5, y = 0, label = "Abundance", size = 2, angle = 270) +
        
        # Add text labels for significant cell groups
        geom_text_repel(aes(c_effect, -v_effect, label = cell_type_label), size = 2.5, data = .x %>% filter(cell_type_label != "")) +
        
        # Set color and alpha scales
        scale_color_manual(values = c("#D3D3D3", "#E41A1C")) +
        scale_alpha_manual(values = c(0.4, 1)) +
        
        # Facet by parameter
        facet_wrap(~parameter, scales = "free") +
        
        # Apply custom theme
        multipanel_theme 
    }
}


#' Plot Boxplot of Cell-group Proportion
#'
#' This function creates a boxplot of cell-group proportions, optionally highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param data_proportion Data frame containing proportions of cell groups.
#' @param factor_of_interest Character string specifying the biological condition of interest.
#' @param .cell_group Column name (unquoted) of the cell group to be analysed.
#' @param .sample Column name (unquoted) of the sample identifier.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param my_theme A ggplot2 theme object to be applied to the plot.
#' @param remove_unwanted_effects Logical value indicating whether to remove unwanted effects. Default is FALSE.
#' 
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#'
#' @return A ggplot object representing the boxplot.
#' 
#' @examples
#' # Example usage:
#' # plot_boxplot(
#' #   .data, data_proportion, factor_of_interest = "condition",
#' #   .cell_group = "cell_group", .sample = "sample",
#' #   significance_threshold = 0.05, my_theme = theme_minimal(),
#' #   remove_unwanted_effects = FALSE
#' # )
#' 
#' @noRd
plot_boxplot = function(
    .data, data_proportion, factor_of_interest, .cell_group,
    .sample, significance_threshold = 0.05, my_theme, remove_unwanted_effects = FALSE
){
  
  # Define the variables as NULL to avoid CRAN NOTES
  stats_name <- NULL
  parameter <- NULL
  stats_value <- NULL
  count_data <- NULL
  generated_proportions <- NULL
  proportion <- NULL
  name <- NULL
  outlier <- NULL
  
  # Function to calculate boxplot statistics
  calc_boxplot_stat <- function(x) {
    coef <- 1.5
    n <- sum(!is.na(x))
    
    # Calculate quantiles
    stats <- quantile(x, probs = c(0.0, 0.25, 0.5, 0.75, 1.0))
    names(stats) <- c("ymin", "lower", "middle", "upper", "ymax")
    iqr <- diff(stats[c(2, 4)])
    
    # Set whiskers
    outliers <- x < (stats[2] - coef * iqr) | x > (stats[4] + coef * iqr)
    if (any(outliers)) {
      stats[c(1, 5)] <- range(c(stats[2:4], x[!outliers]), na.rm = TRUE)
    }
    return(stats)
  }
  
  # Function to remove leading zero from labels
  dropLeadingZero <- function(l){  stringr::str_replace(l, '0(?=.)', '') }
  
  # Define square root transformation and its inverse
  S_sqrt <- function(x){sign(x)*sqrt(abs(x))}
  IS_sqrt <- function(x){x^2*sign(x)}
  S_sqrt_trans <- function() scales::trans_new("S_sqrt",S_sqrt,IS_sqrt)
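  # For example, S_sqrt(-4) = -2 and IS_sqrt(-2) = -4: the signed square root
  # spreads out small proportions on the y axis while remaining invertible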
  
  .cell_group = enquo(.cell_group)
  .sample = enquo(.sample)
  
  # Prepare significance colors
  significance_colors =
    .data %>%
    pivot_longer(
      c(contains("c_"), contains("v_")),
      names_pattern = "([cv])_([a-zA-Z0-9]+)",
      names_to = c("which", "stats_name"),
      values_to = "stats_value"
    ) %>%
    filter(stats_name == "FDR") %>%
    filter(parameter != "(Intercept)") %>%
    filter(stats_value < significance_threshold) %>%
    filter(`factor` == factor_of_interest) 
  
  if(nrow(significance_colors) > 0){
    
    if(.data |> attr("contrasts") |> is.null())
      significance_colors =
        significance_colors %>%
        unite("name", c(which, parameter), remove = FALSE) %>%
        distinct() %>%
        
        # Get clean parameter
        mutate(!!as.symbol(factor_of_interest) := str_replace(parameter, sprintf("^%s", `factor`), "")) %>%
        with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
    else
      significance_colors =
        significance_colors |>
        mutate(count_data = map(count_data, ~ .x |> select(all_of(factor_of_interest)) |> distinct())) |>
        unnest(count_data) |>
        
        # Filter relevant parameters
        mutate( !!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest) ) ) |>
        filter(str_detect(parameter, !!as.symbol(factor_of_interest) )) |>
        
        # Rename
        select(!!.cell_group, !!as.symbol(factor_of_interest), name = parameter) |>
        
        # Merge contrasts
        with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
  }
  
  my_boxplot = ggplot()
  
  if("fit" %in% names(attributes(.data))){
    
    # Remove unwanted effects?
    if(remove_unwanted_effects) formula_composition = as.formula("~ " |> paste(factor_of_interest))
    else formula_composition = NULL
    
    simulated_proportion =
      .data |>
      sccomp_replicate(formula_composition = formula_composition, number_of_draws = 100) |>
      left_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.sample, !!.cell_group))
    
    my_boxplot = my_boxplot +
      
      # Add boxplot for simulated proportions
      stat_summary(
        aes(!!as.symbol(factor_of_interest), (generated_proportions)),
        fun.data = calc_boxplot_stat, geom="boxplot",
        outlier.shape = NA, outlier.color = NA,outlier.size = 0,
        fatten = 0.5, lwd=0.2,
        data =
          simulated_proportion %>%
          inner_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.cell_group)),
        color="blue"
      )
  }
  
  if(nrow(significance_colors) == 0 ||
     length(intersect(
       significance_colors |> pull(!!as.symbol(factor_of_interest)),
       data_proportion |> pull(!!as.symbol(factor_of_interest))
     )) == 0){
    
    my_boxplot=
      my_boxplot +
      
      # Add boxplot without significance colors
      geom_boxplot(
        aes(!!as.symbol(factor_of_interest), proportion,  group=!!as.symbol(factor_of_interest), fill = NULL),
        outlier.shape = NA, outlier.color = NA,outlier.size = 0,
        data =
          data_proportion |>
          mutate(!!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest))) ,
        fatten = 0.5,
        lwd=0.5,
      )
  } else {
    my_boxplot=
      my_boxplot +
      
      # Add boxplot with significance colors
      geom_boxplot(
        aes(!!as.symbol(factor_of_interest), proportion,  group=!!as.symbol(factor_of_interest), fill = name),
        outlier.shape = NA, outlier.color = NA,outlier.size = 0,
        data =
          data_proportion |>
          mutate(!!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest))) %>%
          left_join(significance_colors, by = c(quo_name(.cell_group), factor_of_interest)),
        fatten = 0.5,
        lwd=0.5,
      )
  }
  
  my_boxplot +
    
    # Add jittered points for individual data
    geom_jitter(
      aes(!!as.symbol(factor_of_interest), proportion, shape=outlier, color=outlier,  group=!!as.symbol(factor_of_interest)),
      data = data_proportion,
      position=position_jitterdodge(jitter.height = 0, jitter.width = 0.2),
      size = 0.5
    ) +
    
    # Facet wrap by cell group
    facet_wrap(
      vars(!!.cell_group),
      scales = "free_y",
      nrow = 4
    ) +
    scale_color_manual(values = c("black", "#e11f28")) +
    scale_y_continuous(trans=S_sqrt_trans(), labels = dropLeadingZero) +
    scale_fill_discrete(na.value = "white") +
    xlab("Biological condition") +
    ylab("Cell-group proportion") +
    guides(color="none", alpha="none", size="none") +
    labs(fill="Significant difference") +
    ggtitle("Note: Be careful judging significance (or outliers) visually for lowly abundant cell groups. \nVisualising proportion hides the uncertainty characteristic of count data, that a count-based statistical model can estimate.") +
    my_theme +
    theme(axis.text.x =  element_text(angle=20, hjust = 1), title = element_text(size = 3))
}

#' Plot Scatterplot of Cell-group Proportion
#'
#' This function creates a scatterplot of cell-group proportions, optionally highlighting significant differences based on a given significance threshold.
#'
#' @param .data Data frame containing the main data.
#' @param data_proportion Data frame containing proportions of cell groups.
#' @param factor_of_interest A factor indicating the biological condition of interest.
#' @param .cell_group The cell group to be analysed.
#' @param .sample The sample identifier.
#' @param significance_threshold Numeric value specifying the significance threshold for highlighting differences. Default is 0.05.
#' @param my_theme A ggplot2 theme object to be applied to the plot.
#' @importFrom scales trans_new
#' @importFrom stringr str_replace
#' @importFrom stats quantile
#' @importFrom magrittr equals
#' 
#' 
#' @return A ggplot object representing the scatterplot.
#' @examples
#' # Example usage:
#' # plot_scatterplot(
#' #   .data, data_proportion, factor_of_interest = "condition",
#' #   .cell_group = cell_group, .sample = sample,
#' #   significance_threshold = 0.05, my_theme = theme_minimal()
#' # )
#'
#' @noRd
plot_scatterplot = function(
    .data, data_proportion, factor_of_interest, .cell_group,
    .sample, significance_threshold = 0.05, my_theme
){
  
  # Define the variables as NULL to avoid CRAN NOTES
  stats_name <- NULL
  parameter <- NULL
  stats_value <- NULL
  count_data <- NULL
  generated_proportions <- NULL
  proportion <- NULL
  name <- NULL
  outlier <- NULL
  
  # Function to remove leading zero from labels
  dropLeadingZero <- function(l){  stringr::str_replace(l, '0(?=.)', '') }
  
  # Define square root transformation and its inverse
  S_sqrt <- function(x){sign(x)*sqrt(abs(x))}
  IS_sqrt <- function(x){x^2*sign(x)}
  S_sqrt_trans <- function() scales::trans_new("S_sqrt",S_sqrt,IS_sqrt)
  
  .cell_group = enquo(.cell_group)
  .sample = enquo(.sample)
  
  # Prepare significance colors
  significance_colors =
    .data %>%
    pivot_longer(
      c(contains("c_"), contains("v_")),
      names_pattern = "([cv])_([a-zA-Z0-9]+)",
      names_to = c("which", "stats_name"),
      values_to = "stats_value"
    ) %>%
    filter(stats_name == "FDR") %>%
    filter(parameter != "(Intercept)") %>%
    filter(stats_value < significance_threshold) %>%
    filter(`factor` == factor_of_interest) 
  
  if(nrow(significance_colors) > 0){
    
    if(.data |> attr("contrasts") |> is.null())
      significance_colors =
        significance_colors %>%
        unite("name", c(which, parameter), remove = FALSE) %>%
        distinct() %>%
        
        # Get clean parameter
        mutate(!!as.symbol(factor_of_interest) := str_replace(parameter, sprintf("^%s", `factor`), "")) %>%
        with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
    else
      significance_colors =
        significance_colors |>
        mutate(count_data = map(count_data, ~ .x |> select(all_of(factor_of_interest)) |> distinct())) |>
        unnest(count_data) |>
        
        # Filter relevant parameters
        mutate( !!as.symbol(factor_of_interest) := as.character(!!as.symbol(factor_of_interest) ) ) |>
        filter(str_detect(parameter, !!as.symbol(factor_of_interest) )) |>
        
        # Rename
        select(!!.cell_group, !!as.symbol(factor_of_interest), name = parameter) |>
        
        # Merge contrasts
        with_groups(c(!!.cell_group, !!as.symbol(factor_of_interest)), ~ .x %>% summarise(name = paste(name, collapse = ", ")))
  }
  
  my_scatterplot = ggplot()
  
  if("fit" %in% names(attributes(.data))){
    
    simulated_proportion =
      .data |>
      sccomp_replicate(number_of_draws = 1000) |>
      left_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.sample, !!.cell_group))
    
    my_scatterplot = 
      my_scatterplot +
      
      # Add smoothed line for simulated proportions
      geom_smooth(
        aes(!!as.symbol(factor_of_interest), (generated_proportions)),
        lwd=0.2,
        data =
          simulated_proportion %>%
          inner_join(data_proportion %>% distinct(!!as.symbol(factor_of_interest), !!.cell_group, !!.sample)) ,
        color="blue", fill="blue",
        span = 1
      )
  }
  
  if(
    nrow(significance_colors)==0 ||
    
    significance_colors |> 
    pull(!!as.symbol(factor_of_interest)) |> 
    intersect(
      data_proportion |> 
      pull(!!as.symbol(factor_of_interest))
    ) |> 
    length() |> 
    equals(0)
  ) {
    
    my_scatterplot=
      my_scatterplot +
      
      # Add smoothed line without significance colors
      geom_smooth(
        aes(!!as.symbol(factor_of_interest), proportion, fill = NULL),
        data =
          data_proportion ,
        lwd=0.5,
        color = "black",
        span = 1
      )
  } else {
    my_scatterplot=
      my_scatterplot +
      
      # Add smoothed line with significance colors
      geom_smooth(
        aes(!!as.symbol(factor_of_interest), proportion, fill = name),
        data = data_proportion ,
        lwd=0.5,
        color = "black",
        span = 1
      )
  }
  
  my_scatterplot +
    
    # Add jittered points for individual data
    geom_point(
      aes(!!as.symbol(factor_of_interest), proportion, shape=outlier, color=outlier),
      data = data_proportion,
      position=position_jitterdodge(jitter.height = 0, jitter.width = 0.2),
      size = 0.5
    ) +
    
    # Facet wrap by cell group
    facet_wrap(
      vars(!!.cell_group),
      scales = "free_y",
      nrow = 4
    ) +
    scale_color_manual(values = c("black", "#e11f28")) +
    scale_y_continuous(trans=S_sqrt_trans(), labels = dropLeadingZero) +
    scale_fill_discrete(na.value = "white") +
    xlab("Biological condition") +
    ylab("Cell-group proportion") +
    guides(color="none", alpha="none", size="none") +
    labs(fill="Significant difference") +
    ggtitle("Note: Be careful judging significance (or outliers) visually for lowly abundant cell groups. \nVisualising proportion hides the uncertainty characteristic of count data, that a count-based statistical model can estimate.") +
    my_theme +
    theme(axis.text.x =  element_text(angle=20, hjust = 1), title = element_text(size = 3))
}

draws_to_statistics = function(draws, false_positive_rate, test_composition_above_logit_fold_change, .cell_group, prefix = ""){
  
  # Define the variables as NULL to avoid CRAN NOTES
  M <- NULL
  parameter <- NULL
  bigger_zero <- NULL
  smaller_zero <- NULL
  lower <- NULL
  effect <- NULL
  upper <- NULL
  pH0 <- NULL
  FDR <- NULL
  n_eff <- NULL
  R_k_hat <- NULL
  
  .cell_group = enquo(.cell_group)
  
  draws =
    draws %>%
    group_by(!!.cell_group, M, parameter, rhat, ess_bulk, ess_tail) %>%
    summarise(
      lower = quantile(.value, false_positive_rate / 2),
      effect = quantile(.value, 0.5),
      upper = quantile(.value, 1 - (false_positive_rate / 2)),
      bigger_zero = sum(.value > test_composition_above_logit_fold_change),
      smaller_zero = sum(.value < -test_composition_above_logit_fold_change),
      # R_k_hat = unique(R_k_hat),
      # n_eff = unique(n_eff),
      n = n(),
      .groups = "drop"  # To ungroup the output if needed
    ) |> 
    
    # Calculate probability non 0
    mutate(pH0 =  (1 - (pmax(bigger_zero, smaller_zero) / n))) |>
    with_groups(parameter, ~ mutate(.x, FDR = get_FDR(pH0))) |>
    
    select(!!.cell_group, M, parameter, lower, effect, upper, pH0, FDR, any_of(c("n_eff", "R_k_hat", "rhat", "ess_bulk", "ess_tail"))) |>
    suppressWarnings()
  
  # Setting up names separately because |> is not flexible enough
  draws |>
    setNames(c(colnames(draws)[1:3], sprintf("%s%s", prefix, colnames(draws)[4:ncol(draws)])))
}
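
# Usage sketch (not run; inputs are illustrative): `beta_draws` is a long tibble
# of posterior draws (one row per draw, cell group M and parameter, with rhat,
# ess_bulk and ess_tail columns attached). pH0 is the posterior probability that
# the effect lies within the interval
# [-test_composition_above_logit_fold_change, test_composition_above_logit_fold_change].
#
# beta_draws |> draws_to_statistics(0.05, 0.2, cell_group, prefix = "c_")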

enquos_from_list_of_symbols <- function(...) {
  enquos(...)
}

contrasts_to_enquos = function(contrasts){
  
  # Define the variables as NULL to avoid CRAN NOTES
  . <- NULL
  
  contrasts |> enquo() |> quo_names() |> syms() %>% do.call(enquos_from_list_of_symbols, .)
}


contrasts_to_parameter_list = function(contrasts, drop_back_quotes = TRUE){
  
  if(is.null(names(contrasts)))
    names(contrasts) <- contrasts
  
  contrast_list <-
    contrasts |>
    
    # Remove fractions used in multiplication before or after '*'
    str_remove_all_ignoring_if_inside_backquotes("([0-9]+/[0-9]+ ?\\* ?)|(\\* ?[0-9]+/[0-9]+)") |>
    
    # Remove decimals used in multiplication before or after '*'
    str_remove_all_ignoring_if_inside_backquotes("([-+]?[0-9]+\\.[0-9]+ ?\\* ?)|(\\* ?[-+]?[0-9]+\\.[0-9]+)") |>
    
    # Remove fractions used in divisions
    str_remove_all_ignoring_if_inside_backquotes("/ ?[0-9]+") |>
    
    # Remove standalone numerical constants not inside backquotes
    str_remove_all_ignoring_if_inside_backquotes("\\b[0-9]+\\b") |>
    
    # Split by "+", "-", "*", "/"
    str_split_ignoring_if_inside_backquotes("\\+|-|\\*|/") |>
    unlist() |>
    
    # Remove parentheses and spaces
    str_remove_all_ignoring_if_inside_backquotes("[\\(\\) ]") |>
    
    # Remove empty strings
    {\(x) x[x != ""]}()
  
  if(drop_back_quotes)
    contrast_list <-
    contrast_list |>
    str_remove_all("`")
  
  contrast_list |> unique()
}
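
# Worked example: numeric constants and arithmetic operators are stripped,
# leaving only the parameter names referenced by the contrasts.
#
# contrasts_to_parameter_list(c("typecancer - typehealthy", "1/2 * (a + b)"))
# returns c("typecancer", "typehealthy", "a", "b")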



#' Mutate Data Frame Based on Expression List
#'
#' @description
#' `mutate_from_expr_list` takes a data frame and a list of formula expressions, 
#' and mutates the data frame based on these expressions. It allows for ignoring 
#' errors during the mutation process.
#'
#' @param x A data frame to be mutated.
#' @param formula_expr A named list of formula expressions used for mutation.
#' @param ignore_errors Logical flag indicating whether to ignore errors during mutation.
#'
#' @return A mutated data frame with added or modified columns based on `formula_expr`.
#'
#' @details
#' The function performs various checks and transformations on the formula expressions,
#' ensuring that the specified transformations are valid and can be applied to the data frame.
#' It supports advanced features like handling special characters in column names and intelligent
#' parsing of formulas.
#'
#' @importFrom purrr map2_dfc
#' @importFrom tibble add_column
#' @importFrom tidyselect last_col
#' @importFrom dplyr mutate
#' @importFrom stringr str_subset
#' 
#' @noRd
#' 
mutate_from_expr_list = function(x, formula_expr, ignore_errors = TRUE){
  
  if(formula_expr |> names() |> is.null())
    names(formula_expr) = formula_expr

  # Creating a named vector where the names are the strings to be replaced
  # and the values are empty strings
  contrasts_elements = contrasts_to_parameter_list(formula_expr, drop_back_quotes = FALSE)

  # Check if all elements of the contrasts are among the model parameters
  parameter_names = x |> colnames()
  
  # Check that backquotes are present where column names contain special characters
  require_back_quotes = !contrasts_elements |>  str_remove_all("`") |> contains_only_valid_chars_for_column() 
  has_left_back_quotes = contrasts_elements |>  str_detect("^`") 
  has_right_back_quotes = contrasts_elements |>  str_detect("`$") 
  if_true_not_good = require_back_quotes & !(has_left_back_quotes & has_right_back_quotes)
  
  if(any(if_true_not_good))
    warning(sprintf("sccomp says: for columns which have special characters e.g. %s, you need to use surrounding backquotes ``.", paste(contrasts_elements[!if_true_not_good], sep=", ")))
  
  # Check if columns exist
  contrasts_not_in_the_model = 
    contrasts_elements |> 
    str_remove_all("`") |> 
    setdiff(parameter_names)
  
  contrasts_not_in_the_model = contrasts_not_in_the_model[contrasts_not_in_the_model!=""]
  
  if(length(contrasts_not_in_the_model) > 0 & !ignore_errors)
    warning(sprintf("sccomp says: These components of your contrasts are not present in the model as parameters: %s. Factors including special characters, e.g. \"(Intercept)\" require backquotes e.g. \"`(Intercept)`\" ", paste(contrasts_not_in_the_model, sep = ", ")))
  
  # Calculate
  if(ignore_errors) my_mutate = mutate_ignore_error
  else my_mutate = mutate
  
  map2_dfc(
    formula_expr,
    names(formula_expr),
    ~  x |>
      my_mutate(!!.y := eval(rlang::parse_expr(.x))) |>
      # mutate(!!column_name := eval(rlang::parse_expr(.x))) |>
      select(any_of(.y))
  )  |>
    
    # I could drop this to just result contrasts
    add_column(x |> select(M, .chain, .iteration, .draw), .before = 1)
  
}
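
# Usage sketch (not run; the contrast name is illustrative): `draws` is a tibble
# with columns M, .chain, .iteration, .draw plus one column per model parameter;
# each contrast is evaluated as an expression over the parameter columns.
#
# draws |> mutate_from_expr_list(c(cancer_vs_healthy = "typecancer - typehealthy"))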

mutate_ignore_error = function(x, ...){
  tryCatch(
    {  x |> mutate(...) },
    error=function(cond) {  x  }
  )
}

simulate_multinomial_logit_linear = function(model_input, sd = 0.51){
  
  mu = model_input$X %*% model_input$beta
  
  proportions =
    rnorm(length(mu), mu, sd) %>%
    matrix(nrow = nrow(model_input$X)) %>%
    boot::inv.logit() %>%
    
    # Normalise each sample (row) so proportions sum to one
    apply(1, function(x) x/sum(x)) %>%
    t()
  
  rownames(proportions) = rownames(model_input$X)
  colnames(proportions) = colnames(model_input$beta)
  
  proportions
}
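
# Usage sketch (not run; shapes are illustrative): `X` is samples x factors and
# `beta` is factors x cell types; each row of the result sums to one.
#
# model_input = list(
#   X = matrix(c(1, 1, 1, 1, 0, 0, 1, 1), ncol = 2),
#   beta = matrix(rnorm(8), nrow = 2)
# )
# simulate_multinomial_logit_linear(model_input)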

compress_zero_one = function(y){
  # https://stats.stackexchange.com/questions/48028/beta-regression-of-proportion-data-including-1-and-0
  
  n = length(y)
  (y * (n-1) + 0.5) / n
}
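
# Worked example: with n = 4 observations, exact 0s and 1s are pulled strictly
# inside (0, 1), as required by e.g. Beta regression.
#
# compress_zero_one(c(0, 0.5, 1, 1))
# returns c(0.125, 0.500, 0.875, 0.875)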

# This can be helpful if we want to draw a PCA with uncertainty
get_abundance_contrast_draws = function(.data, contrasts){
  
  # Define the variables as NULL to avoid CRAN NOTES
  X <- NULL
  .value <- NULL
  X_random_effect <- NULL
  .variable <- NULL
  y <- NULL
  M <- NULL
  khat <- NULL
  parameter <- NULL
  n_eff <- NULL
  R_k_hat <- NULL
  
  
  .cell_group = .data |>  attr(".cell_group")
  
  # Beta
  beta_factor_of_interest = .data |> attr("model_input") %$% X |> colnames()
  # beta =
  #   .data |>
  #   attr("fit") %>%
  #   draws_to_tibble_x_y("beta", "C", "M") |>
  #   pivot_wider(names_from = C, values_from = .value) %>%
  #   setNames(colnames(.)[1:5] |> c(beta_factor_of_interest))
  
  if(contrasts |> is.null())
    draws = 
    .data |>
    attr("fit") %>%
    draws_to_tibble_x_y("beta", "C", "M") |> 
    pivot_wider(names_from = C, values_from = .value) %>%
    setNames(colnames(.)[1:5] |> c(beta_factor_of_interest))
  
  else if(
    (beta_factor_of_interest %in% contrasts_to_parameter_list(contrasts)) |> which() |> length() > 0
  )
    
  draws =
    .data |>
    attr("fit") %>%
    draws_to_tibble_x_y("beta", "C", "M") |> 
    left_join(
      beta_factor_of_interest |> enframe(name = "C", value = "parameters_name"),
      by = "C"
    )  |> 
    filter(parameters_name %in% contrasts_to_parameter_list(contrasts)) |> 
    select(-C) |> 
    pivot_wider(names_from = parameters_name, values_from = .value)
  
  else 
    draws = tibble()
  

  # Abundance
  draws = draws |> select(-any_of(".variable"))
  

  # Random effect
  
  beta_random_effect_factor_of_interest = .data |> attr("model_input") %$% X_random_effect |> colnames()
  
  if(
    .data |> attr("model_input") %$% n_random_eff > 0 &&
    (
      contrasts |> is.null() || 
    (beta_random_effect_factor_of_interest %in% contrasts_to_parameter_list(contrasts)) |> which() |> length() > 0
    )  
 ){
    
    
    beta_random_effect =
      .data |>
      attr("fit") %>%
      draws_to_tibble_x_y("random_effect", "C", "M") 
    
    # Add the last component: random effects sum to zero, so the left-out group equals minus the sum of the others
    other_group_random_effect = 
      beta_random_effect |> 
      with_groups(c(C, .chain, .iteration, .draw, .variable ), ~ .x |> summarise(.value = sum(.value))) |> 
      mutate(.value = -.value, M = beta_random_effect |> pull(M) |> max() + 1)
    
    # Regularise the last component: rescale its draws so that the spread of its point estimates matches the mean spread across the estimated groups
    mean_of_the_sd_of_the_point_estimates = 
      beta_random_effect |> 
      group_by(M, C) |> 
      summarise(point_estimate = mean(.value)) |> 
      group_by(M) |> 
      summarise(sd_of_point_estimates = sd(point_estimate)) |> 
      pull(sd_of_point_estimates) |> 
      mean()
    
    other_sd_of_the_point_estimates = 
      other_group_random_effect |> 
      group_by(M, C) |> 
      summarise(point_estimate = mean(.value)) |> 
      group_by(M) |> 
      summarise(sd_of_point_estimates = sd(point_estimate)) |> 
      pull(sd_of_point_estimates)
    
    other_group_random_effect = 
      other_group_random_effect |> 
      mutate(.value = .value / (other_sd_of_the_point_estimates / mean_of_the_sd_of_the_point_estimates))
    
    
    beta_random_effect = 
      beta_random_effect |> 
      bind_rows( other_group_random_effect )
    
    # mutate(is_treg = cell_type =="treg") |>
    #   nest(data = -is_treg) |>
    #   mutate(data = map2(
    #     data, is_treg,
    #     ~ {
    #       if(.y) .x |> mutate(c_effect = c_effect/5 )
    #       else(.x)
    #     }
    #   )) |>
    #   unnest(data) |>
    
    
    # Reshape
    # Speed up if I have contrasts
    if(!contrasts |> is.null())
      beta_random_effect = 
        beta_random_effect |> 
        left_join(
          beta_random_effect_factor_of_interest |> enframe(name = "C", value = "parameters_name"),
          by = "C"
        )  |> 
        filter(parameters_name %in% contrasts_to_parameter_list(contrasts)) |> 
        select(-C) |> 
        pivot_wider(names_from = parameters_name, values_from = .value)
    
    else
      beta_random_effect = 
        beta_random_effect |>
        pivot_wider(names_from = C, values_from = .value) %>%
        setNames(colnames(.)[1:5] |> c(beta_random_effect_factor_of_interest))
    
    # If I don't have fix nor 1st level random effect
    if(draws |> nrow() == 0)
      draws = select(beta_random_effect, -.variable)
    else 
      draws = draws |> 
      left_join(select(beta_random_effect, -.variable),
                by = c("M", ".chain", ".iteration", ".draw")
      )
    
  }  else {
    beta_random_effect_factor_of_interest = ""
  }
  
  # Second random effect. In the future this will be vectorised to arbitrary grouping.
  beta_random_effect_factor_of_interest_2 = .data |> attr("model_input") %$% X_random_effect_2 |> colnames()
  
  if(
    .data |> attr("model_input") %$% n_random_eff > 1 &&
    (
      contrasts |> is.null() || 
      (beta_random_effect_factor_of_interest_2 %in% contrasts_to_parameter_list(contrasts)) |> which() |> length() > 0
    )
  ){
    
    beta_random_effect_2 =
      .data |>
      attr("fit") %>%
      draws_to_tibble_x_y("random_effect_2", "C", "M") 
    
    # Add the last component: random effects sum to zero, so the left-out group equals minus the sum of the others
    other_group_random_effect = 
      beta_random_effect_2 |> 
      with_groups(c(C, .chain, .iteration, .draw, .variable ), ~ .x |> summarise(.value = sum(.value))) |> 
      mutate(.value = -.value, M = beta_random_effect_2 |> pull(M) |> max() + 1)
    
    # Regularise the last component: rescale its draws so that the spread of its point estimates matches the mean spread across the estimated groups
    mean_of_the_sd_of_the_point_estimates = 
      beta_random_effect_2 |> 
      group_by(M, C) |> 
      summarise(point_estimate = mean(.value)) |> 
      group_by(M) |> 
      summarise(sd_of_point_estimates = sd(point_estimate)) |> 
      pull(sd_of_point_estimates) |> 
      mean()
    
    other_sd_of_the_point_estimates = 
      other_group_random_effect |> 
      group_by(M, C) |> 
      summarise(point_estimate = mean(.value)) |> 
      group_by(M) |> 
      summarise(sd_of_point_estimates = sd(point_estimate)) |> 
      pull(sd_of_point_estimates)
    
    other_group_random_effect = 
      other_group_random_effect |> 
      mutate(.value = .value / (other_sd_of_the_point_estimates / mean_of_the_sd_of_the_point_estimates))
    
    
    beta_random_effect_2 = 
      beta_random_effect_2 |> 
      bind_rows( other_group_random_effect )
    
    # Reshape
    # Speed up if I have contrasts
    if(!contrasts |> is.null())
      beta_random_effect_2 = 
      beta_random_effect_2 |> 
      left_join(
        beta_random_effect_factor_of_interest_2 |> enframe(name = "C", value = "parameters_name"),
        by = "C"
      )  |> 
      filter(parameters_name %in% contrasts_to_parameter_list(contrasts)) |> 
      select(-C) |> 
      pivot_wider(names_from = parameters_name, values_from = .value)
    
    else
      beta_random_effect_2 = 
      beta_random_effect_2 |>
      pivot_wider(names_from = C, values_from = .value) %>%
      setNames(colnames(.)[1:5] |> c(beta_random_effect_factor_of_interest_2))
    
    # If I don't have fix nor 1st level random effect
    if(draws |> nrow() == 0)
      draws = select(beta_random_effect_2, -.variable)
    else 
    draws = draws |> 
      left_join(select(beta_random_effect_2, -.variable),
                by = c("M", ".chain", ".iteration", ".draw")
      )
  } else {
    beta_random_effect_factor_of_interest_2 = ""
  }
  
  
  # If I have contrasts, calculate them
  if(!is.null(contrasts))
    draws = 
    draws |> 
    mutate_from_expr_list(contrasts, ignore_errors = FALSE) |>
    select(- any_of(c(beta_factor_of_interest, beta_random_effect_factor_of_interest) |> setdiff(contrasts)) ) 
  
  # Add cell name
  draws = draws |> 
    left_join(
      .data |>
        attr("model_input") %$%
        y %>%
        colnames() |>
        enframe(name = "M", value  = quo_name(.cell_group)),
      by = "M"
    ) %>%
    select(!!.cell_group, everything())
  
  
  # If no contrasts of interest, return only the cell-group identifiers
  if(ncol(draws)==5) return(draws |> distinct(M, !!.cell_group))
  
  # Get convergence
  convergence_df =
    .data |>
    attr("fit") |>
    summary_to_tibble("beta", "C", "M") |>
    
    # Add cell name
    left_join(
      .data |>
        attr("model_input") %$%
        y %>%
        colnames() |>
        enframe(name = "M", value  = quo_name(.cell_group)),
      by = "M"
    ) |>
    
    # factor names
    left_join(
      beta_factor_of_interest |>
        enframe(name = "C", value = "parameter"),
      by = "C"
    )
  
  # if ("Rhat" %in% colnames(convergence_df)) {
  #   convergence_df <- rename(convergence_df, R_k_hat = Rhat)
  # } else if ("khat" %in% colnames(convergence_df)) {
  #   convergence_df <- rename(convergence_df, R_k_hat = khat)
  # }
  
  
  convergence_df =
    convergence_df |> 
    select(!!.cell_group, parameter, any_of(c("n_eff", "R_k_hat", "rhat", "ess_bulk", "ess_tail"))) |>
    suppressWarnings()
  
  draws |>
    pivot_longer(-c(1:5), names_to = "parameter", values_to = ".value") |>
    
    # Attach convergence statistics (derived contrast parameters get NA)
    left_join(convergence_df, by = c(quo_name(.cell_group), "parameter")) |>
    
    # Restore the parameter order, which pivot_longer does not preserve
    mutate(parameter = parameter |> fct_relevel(colnames(draws)[-c(1:5)])) |>
    arrange(parameter)
  
}

#' @importFrom forcats fct_relevel
#' @noRd
get_variability_contrast_draws = function(.data, contrasts){
  
  # Define the variables as NULL to avoid CRAN NOTES
  XA <- NULL
  .value <- NULL
  y <- NULL
  M <- NULL
  khat <- NULL
  parameter <- NULL
  n_eff <- NULL
  R_k_hat <- NULL
  
  .cell_group = .data |>  attr(".cell_group")
  
  variability_factor_of_interest = .data |> attr("model_input") %$% XA |> colnames()
  
  draws =
    
    .data |>
    attr("fit") %>%
    draws_to_tibble_x_y("alpha_normalised", "C", "M") |>
    
    # We want variability, not concentration
    mutate(.value = -.value) 
  
  # Reshape
  # Speed up if I have contrasts
  if(!contrasts |> is.null())
    draws = 
    draws |> 
    left_join(
      variability_factor_of_interest |> enframe(name = "C", value = "parameters_name"),
      by = "C"
    )  |> 
    filter(parameters_name %in% contrasts_to_parameter_list(contrasts)) |> 
    select(-C) |> 
    pivot_wider(names_from = parameters_name, values_from = .value) |> 
    select( -.variable) 
  
  else
    draws =
    draws |>  
    pivot_wider(names_from = C, values_from = .value) %>%
    setNames(colnames(.)[1:5] |> c(variability_factor_of_interest)) |>
    select( -.variable) 
  
  # If I have contrasts, calculate them
  if (!is.null(contrasts)) 
    draws <- mutate_from_expr_list(draws, contrasts, ignore_errors = TRUE)
  
  draws =  draws |>
    
    # Add cell name
    left_join(
      .data |> attr("model_input") %$%
        y %>%
        colnames() |>
        enframe(name = "M", value  = quo_name(.cell_group)),
      by = "M"
    ) %>%
    select(!!.cell_group, everything())
  
  # If no contrasts of interest, return only the cell-group identifiers
  if(ncol(draws)==5) return(draws |> distinct(M, !!.cell_group))
  
  # Get convergence
  convergence_df =
    .data |>
    attr("fit") |>
    summary_to_tibble("alpha_normalised", "C", "M") |>
    
    # Add cell name
    left_join(
      .data |>
        attr("model_input") %$%
        y %>%
        colnames() |>
        enframe(name = "M", value  = quo_name(.cell_group)),
      by = "M"
    ) |>
    
    # factor names
    left_join(
      variability_factor_of_interest |>
        enframe(name = "C", value = "parameter"),
      by = "C"
    )
  
  convergence_df =
    convergence_df |> 
    select(!!.cell_group, parameter, any_of(c("n_eff", "R_k_hat", "rhat", "ess_bulk", "ess_tail"))) |>
    suppressWarnings()
  
  
  draws |>
    pivot_longer(-c(1:5), names_to = "parameter", values_to = ".value") |>
    
    # Attach convergence statistics (derived contrast parameters get NA)
    left_join(convergence_df, by = c(quo_name(.cell_group), "parameter")) |>
    
    # Restore the parameter order, which pivot_longer does not preserve
    mutate(parameter = parameter |> fct_relevel(colnames(draws)[-c(1:5)])) |>
    arrange(parameter)
  
}

#' @importFrom tibble deframe
#'
#' @noRd
replicate_data = function(.data,
                          formula_composition = NULL,
                          formula_variability = NULL,
                          new_data = NULL,
                          number_of_draws = 1,
                          mcmc_seed = sample(1e5, 1),
                          cores = detectCores()){
  
  
  # Select model based on noise model
  noise_model = attr(.data, "noise_model")
  
  model_input = attr(.data, "model_input")
  .sample = attr(.data, ".sample")
  .cell_group = attr(.data, ".cell_group")
  .count = attr(.data, ".count")
  
  # Composition
  if(is.null(formula_composition)) formula_composition =  .data |> attr("formula_composition")
  
  # New data
  if(new_data |> is.null())
    new_data =
    .data |>
    select(count_data) |>
    unnest(count_data) |>
    distinct()
  
  # If seurat
  else if(new_data |> is("Seurat")) new_data = new_data[[]]
  
  # Just subset
  new_data = new_data |> .subset(!!.sample)
  
  
  # Check if the input new data is not suitable
  if(!parse_formula(formula_composition) %in% colnames(new_data) |> all())
    stop(sprintf("sccomp says: your `new_data` might be malformed. It might have the covariate columns with multiple values for some element of the \"%s\" column. As a generic example, a sample identifier (\"Sample_123\") might be associated with multiple treatment values, or age values.", quo_name(.sample)))
  
  
  # Match factors with old data
  nrow_new_data = nrow(new_data)
  new_exposure = new_data |>
    nest(data = -!!.sample) |>
    mutate(exposure = map_dbl(
      data,
      ~{
        if (quo_name(.count) %in% colnames(.x)) .x |> pull(!!.count) |> sum()
        else 5000
      })) |>
    select(!!.sample, exposure) |>
    deframe() |>
    as.array()
  
  # Update data, merge with old data because
  # I need the same ordering of the design matrix
  old_data = 
    .data |>
    select(count_data) |>
    unnest(count_data) |>
    select(-!!.count) |>
    select(new_data |> as_tibble() |> colnames() |>  any_of()) |>
    distinct() |>
    
    # Change sample names to make unique
    mutate(dummy = "OLD") |>
    tidyr::unite(!!.sample, c(!!.sample, dummy), sep="___")
    
  # Harmonise factors
  new_data = new_data |> as_tibble() |> harmonise_factor_levels(old_data)
  
  # Check if some values were not present in the original data
  if(
    new_data |> nrow() > 
    new_data |> select(-!!.sample) |> drop_na() |> nrow()
  )
    stop(
      "sccomp says: some factor values were not present in the original training data. \n",
      new_data
      )

  new_data =  old_data |> bind_rows( new_data )
  
  new_X =
    new_data |>
    get_design_matrix(
      
      # Drop random intercept
      formula_composition |>
        as.character() |>
        str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
        paste(collapse="") |>
        as.formula(),
      !!.sample
    ) |>
    tail(nrow_new_data) %>%
    
    # Remove columns that are not in the original design matrix
    .[,colnames(.) %in% colnames(model_input$X), drop=FALSE]
  
  # Check that all effect combination were present when the model was fitted
  check_missing_parameters(
    new_X |> colnames(), 
    model_input$X |> colnames()
  ) 
  
  X_which =
    colnames(new_X) |>
    match(
      model_input$X %>%
        colnames()
    ) |>
    na.omit() |>
    as.array()
  
  # Variability
  if(is.null(formula_variability)) formula_variability =  .data |> attr("formula_variability")
  
  new_Xa =
    new_data |>
    get_design_matrix(
      
      # Drop random intercept
      formula_variability |>
        as.character() |>
        str_remove_all("\\+ ?\\(.+\\|.+\\)") |>
        paste(collapse="") |>
        as.formula(),
      !!.sample
    ) |>
    tail(nrow_new_data) %>%
    
    # Remove columns that are not in the original design matrix
    .[,colnames(.) %in% colnames(model_input$Xa), drop=FALSE]
  
  XA_which =
    colnames(new_Xa) |>
    match(
      model_input %$%
        Xa %>%
        colnames()
    ) |>
    na.omit() |>
    as.array()
  
  # If I want to replicate data with intercept and I don't have intercept in my fit
  create_intercept =
    model_input %$% intercept_in_design |> not() &
    "(Intercept)" %in% colnames(new_X)
  if(create_intercept) warning("sccomp says: your estimated model is intercept free, while your desired replicated data do have an intercept term. The intercept estimate will be calculated averaging your first factor in your formula ~ 0 + <factor>. If you don't know the meaning of this warning, this is likely undesired, and please reconsider your formula for replicate_data()")
  
  # Original grouping
  original_grouping_names = .data |> attr("formula_composition") |> formula_to_random_effect_formulae() |> pull(grouping)
  
  # Random intercept
  random_effect_elements = parse_formula_random_effect(formula_composition)
  
  random_effect_grouping =
    new_data %>%
    
    get_random_effect_design2(
      !!.sample,
      formula_composition
    )
  
  
  if((random_effect_grouping$grouping %in% original_grouping_names[1]) |> any()) {
    
    # HAVE TO DEBUG
    new_X_random_effect =
      random_effect_grouping |>
      filter(grouping==original_grouping_names[1]) |> 
      mutate(design_matrix = map(
        design,
        ~ ..1 |>
          select(!!.sample, group___label, value) |>
          pivot_wider(names_from = group___label, values_from = value) |>
          mutate(across(everything(), ~ .x |> replace_na(0)))
      )) |>
      
      # Merge
      pull(design_matrix) |> 
      _[[1]] |> 
      as_matrix(rownames = quo_name(.sample))  |>
      tail(nrow_new_data)
    
    # Check that all effect combination were present when the model was fitted
    check_missing_parameters(
      new_X_random_effect |> colnames(), 
      model_input %$% X_random_effect |> colnames()
    ) 
    
    # I HAVE TO KEEP GROUP NAME IN COLUMN NAME
    X_random_effect_which =
      colnames(new_X_random_effect) |>
      match(
        model_input %$%
          X_random_effect %>%
          colnames()
      ) |>
      as.array()
    
    # Check if I have column in the new design that are not in the old one
    missing_columns = new_X_random_effect |> colnames() |> setdiff(colnames(model_input$X_random_effect))
    if(missing_columns |> length() > 0)
      stop(glue("sccomp says: the columns in the design matrix {paste(missing_columns, collapse= ' ,')} are missing from the design matrix of the estimate-input object. Please make sure your new model is a sub-model of your estimated one."))
    
  }
  else{
    X_random_effect_which = array()[0]
    new_X_random_effect = matrix(rep(0, nrow_new_data))[,0, drop=FALSE]
    
  }
  if((random_effect_grouping$grouping %in% original_grouping_names[2]) |> any()){
    # HAVE TO DEBUG
    new_X_random_effect_2 =
      random_effect_grouping |>
      filter(grouping==original_grouping_names[2]) |> 
      mutate(design_matrix = map(
        design,
        ~ ..1 |>
          select(!!.sample, group___label, value) |>
          pivot_wider(names_from = group___label, values_from = value) |>
          mutate(across(everything(), ~ .x |> replace_na(0)))
      )) |>
      
      # Merge
      pull(design_matrix) |> 
      _[[1]] |> 
      as_matrix(rownames = quo_name(.sample))  |>
      tail(nrow_new_data)
    
    
    # Check that all effect combination were present when the model was fitted
    check_missing_parameters(
      new_X_random_effect_2 |> colnames(), 
      model_input %$% X_random_effect_2 |> colnames()
    ) 
      
    # Duplicated from the first random-effect block above
    X_random_effect_which_2 =
      colnames(new_X_random_effect_2) |>
      match(
        model_input %$%
          X_random_effect_2 %>%
          colnames()
      ) |>
      as.array()
    
  }
  else{
    X_random_effect_which_2 = array()[0]
    new_X_random_effect_2 = matrix(rep(0, nrow_new_data))[,0, drop=FALSE]
  }
  
  # New X
  model_input$X_original = model_input$X
  model_input$X = new_X
  model_input$Xa = new_Xa
  model_input$N_original = model_input$N
  model_input$N = nrow_new_data
  model_input$exposure = new_exposure
  
  model_input$X_random_effect = new_X_random_effect
  model_input$X_random_effect_2 = new_X_random_effect_2
  
  model_input$ncol_X_random_eff_new = ncol(new_X_random_effect) |> c(ncol(new_X_random_effect_2))
  
  
  
  number_of_draws_in_the_fit = attr(.data, "fit") |>  get_output_samples()
  
  # To avoid error in case of a NULL posterior sample
  number_of_draws = min(number_of_draws, number_of_draws_in_the_fit)
  
  # Load model
  mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores)
  
  
  # Generate quantities
  mod_rng |> sample_safe(
    generate_quantities_fx,
    attr(.data, "fit")$draws(format = "matrix")[
      sample(seq_len(number_of_draws_in_the_fit), size=number_of_draws),, drop=FALSE
    ],
    data = model_input |> c(list(
      
      # Add subset of coefficients
      length_X_which = length(X_which),
      length_XA_which = length(XA_which),
      X_which = X_which,
      XA_which = XA_which,
      
      # Random intercept
      X_random_effect_which = X_random_effect_which,
      X_random_effect_which_2 = X_random_effect_which_2,
      length_X_random_effect_which = length(X_random_effect_which) |> c(length(X_random_effect_which_2)),
      
      # Should I create intercept for generate quantities
      create_intercept = create_intercept
      
    )),
    seed = mcmc_seed, 
    threads_per_chain = 1
    
  )
}

get_model_from_data = function(file_compiled_model, model_code){
  if(file.exists(file_compiled_model))
    readRDS(file_compiled_model)
  else {
    model_generate = stan_model(model_code = model_code)
    model_generate  %>% saveRDS(file_compiled_model)
    model_generate
    
  }
}
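
# Usage sketch (not run; the file name and `my_stan_code` are illustrative):
# compile the Stan model once and reuse the cached object on later calls.
#
# model = get_model_from_data(
#   file.path(sccomp_stan_models_cache_dir, "my_model.rds"),
#   my_stan_code
# )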

add_formula_columns = function(.data, .original_data, .sample,  formula_composition){
  
  .sample = enquo(.sample)
  
  formula_elements = parse_formula(formula_composition)
  
  # If no formula return the input
  if(length(formula_elements) == 0) return(.data)
  
  # Get random intercept
  .grouping_for_random_effect = parse_formula_random_effect(formula_composition) |> pull(grouping) |> unique()
  
  data_frame_formula =
    .original_data %>%
    as_tibble() |>
    select( !!.sample, formula_elements, any_of(.grouping_for_random_effect) ) %>%
    distinct()
  
  .data |>
    left_join(data_frame_formula, by = quo_name(.sample) )
  
}

#' chatGPT - Remove Specified Regex Pattern from Each String in a Vector
#'
#' This function takes a vector of strings and a regular expression pattern.
#' It removes occurrences of the pattern from each string, except where the pattern
#' is found inside backticks. The function returns a vector of cleaned strings.
#'
#' @param text_vector A character vector with the strings to be processed.
#' @param regex A character string containing a regular expression pattern to be removed
#' from the text.
#'
#' @return A character vector with the regex pattern removed from each string.
#' Occurrences of the pattern inside backticks are not removed.
#'
#' @examples
#' texts <- c("A string with (some) parentheses and `a (parenthesis) inside` backticks",
#'            "Another string with (extra) parentheses")
#' cleaned_texts <- str_remove_all_ignoring_if_inside_backquotes(texts, "\\(")
#' print(cleaned_texts)
#' 
#' @noRd
str_remove_all_ignoring_if_inside_backquotes <- function(text_vector, regex) {
  # Nested function to handle regex removal for a single string
  remove_regex_chars <- function(text, regex) {
    inside_backticks <- FALSE
    result <- ""
    skip <- 0
    
    chars <- strsplit(text, "")[[1]]
    for (i in seq_along(chars)) {
      if (skip > 0) {
        skip <- skip - 1
        next
      }
      
      char <- chars[i]
      if (char == "`") {
        inside_backticks <- !inside_backticks
        result <- paste0(result, char)
      } else if (!inside_backticks) {
        # Check the remaining text against the regex
        remaining_text <- paste(chars[i:length(chars)], collapse = "")
        match <- regexpr(regex, remaining_text)
        
        if (attr(match, "match.length") > 0 && match[1] == 1) {
          # Skip the length of the matched text
          skip <- attr(match, "match.length") - 1
          next
        } else {
          result <- paste0(result, char)
        }
      } else {
        result <- paste0(result, char)
      }
    }
    
    return(result)
  }
  
  # Apply the function to each element in the vector
  sapply(text_vector, remove_regex_chars, regex)
}


#' chatGPT - Split Each String in a Vector by a Specified Regex Pattern
#'
#' This function takes a vector of strings and a regular expression pattern. It splits
#' each string based on the pattern, except where the pattern is found inside backticks.
#' The function returns a list, with each element being a vector of the split segments
#' of the corresponding input string.
#'
#' @param text_vector A character vector with the strings to be processed.
#' @param regex A character string containing a regular expression pattern used for splitting
#' the text.
#'
#' @return A list of character vectors. Each list element corresponds to an input string
#' from `text_vector`, split according to `regex`, excluding occurrences inside backticks.
#'
#' @examples
#' texts <- c("A string with, some, commas, and `a, comma, inside` backticks",
#'            "Another string, with, commas")
#' split_texts <- split_regex_chars_from_vector(texts, ",")
#' print(split_texts)
#' 
#' @noRd
str_split_ignoring_if_inside_backquotes <- function(text_vector, regex) {
  # Nested function to handle regex split for a single string
  split_regex_chars <- function(text, regex) {
    inside_backticks <- FALSE
    result <- c()
    current_segment <- ""
    skip <- 0
    
    chars <- strsplit(text, "")[[1]]
    for (i in seq_along(chars)) {
      if (skip > 0) {
        skip <- skip - 1
        next
      }
      
      char <- chars[i]
      if (char == "`") {
        inside_backticks <- !inside_backticks
        current_segment <- paste0(current_segment, char)
      } else if (!inside_backticks) {
        # Check the remaining text against the regex
        remaining_text <- paste(chars[i:length(chars)], collapse = "")
        match <- regexpr(regex, remaining_text)
        
        if (attr(match, "match.length") > 0 && match[1] == 1) {
          # Add current segment to result and start a new segment
          result <- c(result, current_segment)
          current_segment <- ""
          # Skip the remaining characters of the matched text; a counter is
          # used because reassigning `i` inside a for loop has no effect in R
          skip <- attr(match, "match.length") - 1
        } else {
          current_segment <- paste0(current_segment, char)
        }
      } else {
        current_segment <- paste0(current_segment, char)
      }
    }
    
    # Add the last segment to the result
    result <- c(result, current_segment)
    return(result)
  }
  
  # Apply the function to each element in the vector
  lapply(text_vector, split_regex_chars, regex)
}


#' chatGPT - Check for Valid Column Names in Tidyverse Context
#'
#' This function checks if each given column name in a vector contains only valid characters 
#' (letters, numbers, periods, and underscores) and does not start with a digit 
#' or an underscore, which are the conditions for a valid column name in `tidyverse`.
#'
#' @param column_names A character vector representing the column names to be checked.
#'
#' @return A logical vector: `TRUE` for each column name that contains only valid characters 
#' and does not start with a digit or an underscore; `FALSE` otherwise.
#'
#' @examples
#' contains_only_valid_chars_for_column(c("valid_column", "invalid column", "valid123", 
#' "123startWithNumber", "_startWithUnderscore"))
#'
#' @noRd
contains_only_valid_chars_for_column <- function(column_names) {
  # Function to check a single column name
  check_validity <- function(column_name) {
    # Regex pattern for valid characters (letters, numbers, periods, underscores)
    valid_char_pattern <- "[A-Za-z0-9._]"
    
    # Check if all characters in the string match the valid pattern
    all_chars_valid <- stringr::str_detect(column_name, paste0("^", valid_char_pattern, "+$"))
    
    # Check for leading digits or underscores
    starts_with_digit_or_underscore <- stringr::str_detect(column_name, "^[0-9_]")
    
    return(all_chars_valid && !starts_with_digit_or_underscore)
  }
  
  # Apply the check to each element of the vector
  sapply(column_names, check_validity)
}

#' Check Sample Consistency of Factors
#'
#' This function checks for each sample in the provided data frame if the number of unique
#' covariate values from a specified formula matches the number of samples. It is useful for
#' verifying data consistency before statistical analysis. The function stops and throws an
#' error if inconsistencies are found.
#'
#' @importFrom dplyr select
#' @importFrom dplyr filter
#' @importFrom dplyr mutate
#' @importFrom dplyr pull
#' @importFrom dplyr distinct
#' @importFrom tidyr pivot_longer
#' @importFrom purrr map_lgl
#'
#' @param .data A data frame containing the samples and covariates.
#' @param my_formula A formula specifying the covariates to check, passed as a string.
#'
#' @details The function selects the sample and covariates based on `my_formula`, pivots
#' the data longer so each row represents a unique sample-covariate combination, nests
#' the data by covariate name, and checks that the number of unique sample-covariate
#' pairs does not exceed the number of samples for each covariate.
#'
#' @return This function does not return a value; it stops with an error message if any
#' inconsistencies are found.
#'
#' @noRd
#' @keywords internal
check_sample_consistency_of_factors = function(.data, my_formula, .sample, .cell_group){
  
  .sample = enquo(.sample)
  .cell_group = enquo(.cell_group)
  
  # Check that I have one set of covariates per sample
  first_cell_group = .data |> pull(!!.cell_group) |> _[[1]]
  
  any_covariate_not_matching_sample_size = 
    .data |> 
    filter(!!.cell_group == first_cell_group) |> 
    select(!!.sample, all_of(parse_formula(my_formula))) |> 
    pivot_longer(-!!.sample, values_transform = as.character) |> 
    nest(data = -name) |> 
    mutate(correct_size = map_lgl(data,
                                  ~ 
                                    (.x |> distinct(!!.sample, value) |> nrow()) <= 
                                    (.x |> distinct(!!.sample) |> nrow())
    )) |> 
    filter(!correct_size)
  
  if( any_covariate_not_matching_sample_size |> nrow() > 0 ) stop(
    sprintf("sccomp says: your \"%s\" factor(s) is(are) mismatched across samples. ", any_covariate_not_matching_sample_size |> pull(name) |> paste(collapse = ", ")),
    "For example, sample_bar having more than one value for factor_foo. ",
    "For sample_bar you should have one value for factor_foo. consistent across groups (e.g. cell types)."
  )
  
}
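
# Illustrative use of check_sample_consistency_of_factors(), kept as a comment;
# the column names sample/type/cell_group are hypothetical:
# tibble::tibble(
#   sample     = c("s1", "s1", "s2", "s2"),
#   type       = c("treated", "untreated", "treated", "treated"),
#   cell_group = "B"
# ) |>
#   check_sample_consistency_of_factors(~ type, sample, cell_group)
# #> Error: sccomp says: your "type" factor(s) is(are) mismatched across samples. ...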



#' chatGPT - Intelligently Remove Surrounding Brackets from Each String in a Vector
#'
#' This function processes each string in a vector and removes surrounding brackets if the content
#' within the brackets includes any of '+', '-', or '*', and if the brackets are not 
#' within backticks. This is particularly useful for handling formula-like strings.
#'
#' @param text A character vector with strings from which the brackets will be removed based on
#' specific conditions.
#'
#' @return A character vector with the specified brackets removed from each string.
#'
#' @examples
#' str_remove_brackets_from_formula_intelligently(c("This is a test (with + brackets)", "`a (test) inside` backticks", "(another test)"))
#'
#' @noRd
str_remove_brackets_from_formula_intelligently <- function(text) {
  # Function to remove brackets from a single string
  remove_brackets_single <- function(s) {
    inside_backticks <- FALSE
    bracket_depth <- 0
    valid_bracket_content <- FALSE
    result <- ""
    bracket_content <- ""
    
    chars <- strsplit(s, "")[[1]]
    
    for (i in seq_along(chars)) {
      char <- chars[i]
      
      if (char == "`") {
        inside_backticks <- !inside_backticks
      }
      
      if (!inside_backticks) {
        if (char == "(") {
          bracket_depth <- bracket_depth + 1
          if (bracket_depth > 1) {
            bracket_content <- paste0(bracket_content, char)
          }
          next
        } else if (char == ")") {
          bracket_depth <- bracket_depth - 1
          if (bracket_depth == 0) {
            if (grepl("[\\+\\-\\*]", bracket_content)) {
              result <- paste0(result, bracket_content)
            } else {
              result <- paste0(result, "(", bracket_content, ")")
            }
            bracket_content <- ""
            next
          }
        }
        
        if (bracket_depth >= 1) {
          bracket_content <- paste0(bracket_content, char)
        } else {
          result <- paste0(result, char)
        }
      } else {
        result <- paste0(result, char)
      }
    }
    
    return(result)
  }
  
  # Apply the function to each element in the vector
  sapply(text, remove_brackets_single)
}
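
# Illustrative behaviour of str_remove_brackets_from_formula_intelligently(),
# kept as a comment; brackets are dropped only when the content contains
# + - or * and is not protected by backticks:
# str_remove_brackets_from_formula_intelligently(c("(a + b)", "(ab)", "`a (x + y) b`")) |> unname()
# #> [1] "a + b"         "(ab)"          "`a (x + y) b`"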




#' Convert an array of quosures (e.g. c(col_a, col_b)) into a character vector
#'
#' @keywords internal
#' @noRd
#'
#' @importFrom rlang quo_name
#' @importFrom rlang quo_squash
#'
#' @param v An array of quosures (e.g. c(col_a, col_b))
#'
#' @return A character vector
quo_names <- function(v) {
  
  v = quo_name(quo_squash(v))
  gsub('^c\\(|`|\\)$', '', v) |>
    strsplit(', ') |>
    unlist()
}
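
# Illustrative use of quo_names(), kept as a comment:
# quo_names(rlang::quo(c(col_a, col_b)))
# #> [1] "col_a" "col_b"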

#' Add class to object
#'
#' @keywords internal
#' @noRd
#'
#' @param var A tibble
#' @param name A character name of the class to prepend
#'
#' @return A tibble with an additional class
add_class = function(var, name) {
  
  if(!name %in% class(var)) class(var) <- append(class(var),name, after = 0)
  
  var
}
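
# Illustrative use of add_class(), kept as a comment; the class name
# "sccomp_tbl" is hypothetical:
# tibble::tibble(x = 1) |> add_class("sccomp_tbl") |> class()
# #> [1] "sccomp_tbl" "tbl_df"     "tbl"        "data.frame"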


#' Get Output Samples from a Stan Fit Object
#'
#' This function retrieves the number of output draws from a Stan fit object,
#' supporting the different fitting methods (variational, MCMC and pathfinder)
#' based on the metadata available within the object.
#'
#' @param fit A CmdStanR fit object (e.g. `CmdStanMCMC` or `CmdStanVB`), which is the result of fitting a model via Stan.
#' @return The number of output draws used in the Stan model:
#'         `output_samples` (variational) if present, otherwise `iter_sampling`
#'         (MCMC), otherwise `num_psis_draws` (pathfinder).
#' @examples
#' \donttest{
#'   # Assuming 'fit' is a CmdStanR fit object obtained from running a Stan model
#'   # samples_count = get_output_samples(fit)
#' }
#'
#' @export
#' 
get_output_samples = function(fit){
  
  # Check if the output_samples field is present in the metadata of the fit object
  # This is generally available when the model is fit with variational inference
  if(!is.null(fit$metadata()$output_samples)) {
    # Return the output_samples from the metadata
    fit$metadata()$output_samples
  }
  
  # If the output_samples field is not present, check for iter_sampling
  # This is typically available when the model is fit with MCMC (Markov chain Monte Carlo)
  else if(!is.null(fit$metadata()$iter_sampling)) {
    # Return the iter_sampling from the metadata
    fit$metadata()$iter_sampling
  }
  else
    # Otherwise fall back to the number of PSIS draws (e.g. from pathfinder)
    fit$metadata()$num_psis_draws
}
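
# Illustrative use of get_output_samples(), kept as a comment; `model` and
# `data_list` are hypothetical:
# fit <- model$sample(data = data_list, iter_sampling = 500)
# get_output_samples(fit)
# #> [1] 500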


#' Load, Compile, and Cache a Stan Model
#'
#' This function attempts to load a precompiled Stan model using the `instantiate` package.
#' If the model is not found in the cache or force recompilation is requested, it will locate
#' the Stan model file within the `sccomp` package, compile it using `cmdstanr`, and save the 
#' compiled model to the cache directory for future use.
#'
#' @param name A character string representing the name of the Stan model (without the `.stan` extension).
#' @param cache_dir A character string representing the path to the cache directory where compiled models are saved. 
#' Defaults to `sccomp_stan_models_cache_dir`.
#' @param force A logical value. If `TRUE`, the model will be recompiled even if it exists in the cache. 
#' Defaults to `FALSE`.
#' @param threads An integer specifying the number of threads to use for compilation. 
#' Defaults to `1`.
#' 
#' @return A compiled Stan model object from `cmdstanr`.
#' 
#' @importFrom instantiate stan_package_model
#' @importFrom instantiate stan_package_compile
#' 
#' @noRd
#' 
#' @examples
#' \donttest{
#'   model <- load_model("glm_multi_beta_binomial_", "~/cache", force = FALSE, threads = 1)
#' }
load_model <- function(name, cache_dir = sccomp_stan_models_cache_dir, force = FALSE, threads = 1) {
  
  
  # Try to load the model from cache
  
  # RDS compiled model
  cache_dir |> dir.create(showWarnings = FALSE, recursive = TRUE)
  cache_file <- file.path(cache_dir, paste0(name, ".rds"))
  
  # .STAN raw model
  stan_model_path <- system.file("stan", paste0(name, ".stan"), package = "sccomp")
  
  if (file.exists(cache_file) && !force) {
    message("Loading model from cache...")
    return(readRDS(cache_file))
  }
  
  # Otherwise, compile the Stan model file shipped with the package
  message("Precompiled model not found. Compiling the model...")
  
  # Compile the Stan model using cmdstanr with threading support enabled
  instantiate::stan_package_compile(
    stan_model_path, 
    cpp_options = list(stan_threads = TRUE),
    force_recompile = TRUE, 
    threads = threads, 
    dir = system.file("stan", package = "sccomp")
  )
  mod = instantiate::stan_package_model(
    name = name, 
    package = "sccomp", 
    compile = TRUE,
    cpp_options = list(stan_threads = TRUE)
  ) |> suppressWarnings()
  
  # Save the compiled model object to cache
  saveRDS(mod, file = cache_file)
  message("Model compiled and saved to cache successfully.")
  
  return(mod)
  
}
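
# Illustrative use of load_model(), kept as a comment; the model is compiled
# on the first call and loaded from the cache on subsequent calls:
# mod <- load_model("glm_multi_beta_binomial_")
# mod <- load_model("glm_multi_beta_binomial_", force = TRUE)  # force recompilation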

#' Check and Install cmdstanr and CmdStan
#'
#' This function checks if the `cmdstanr` package and CmdStan are installed. 
#' If they are not installed, it installs them automatically in non-interactive sessions
#' or asks for permission to install them in interactive sessions.
#'
#' @importFrom instantiate stan_cmdstan_exists
#' @importFrom rlang check_installed
#' @importFrom rlang abort
#' @return NULL
#' 
#' @noRd
check_and_install_cmdstanr <- function() {
  
  # Check if cmdstanr is installed
  # from https://github.com/wlandau/instantiate/blob/33989d74c26f349e292e5efc11c267b3a1b71d3f/R/utils_assert.R#L114
  
  stan_error <- function(message = NULL) {
    stan_stop(
      message = message,
      class = c("stan_error", "stan")
    )
  }
  
  stan_stop <- function(message, class) {
    old <- getOption("rlang_backtrace_on_error")
    on.exit(options(rlang_backtrace_on_error = old))
    options(rlang_backtrace_on_error = "none")
    abort(message = message, class = class, call = emptyenv())
  }
  
  tryCatch(
    rlang::check_installed(
      pkg = "cmdstanr",
      reason = paste(
        "The {cmdstanr} package is required in order to install",
        "CmdStan and run Stan models. Please install it manually using",
        "install.packages(pkgs = \"cmdstanr\",",
        "repos = c(\"https://mc-stan.org/r-packages/\", getOption(\"repos\"))"
      )
    ),
    error = function(e) {
      clear_stan_model_cache()
      stan_error(conditionMessage(e))
    }
  )
  
  # Check if CmdStan is installed
  if (!stan_cmdstan_exists()) {
    
    clear_stan_model_cache()
    
    stop(
      "cmdstan is required to proceed.\n\n",
      "You can install CmdStan by running the following command:\n",
      "cmdstanr::check_cmdstan_toolchain(fix = TRUE)\n",
      "cmdstanr::install_cmdstan()\n",
      "This will install the latest version of CmdStan. For more information, visit:\n",
      "https://mc-stan.org/users/interfaces/cmdstan"
    )
  }
}


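#' Drop the Environment Attached to an Object
#'
#' Replaces the environment of an object (typically a function or formula) with
#' a fresh empty one. This is useful, for example, to avoid serialising a large
#' enclosing environment together with the object.
#'
#' @param obj An object that may carry an environment.
#'
#' @return The object, with an empty environment if it had one.
#'
#' @keywords internal
#' @noRd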
drop_environment <- function(obj) {
  # Check if the object has an environment
  if (!is.null(environment(obj))) {
    environment(obj) <- new.env()
  }
  return(obj)
}

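#' Check for Parameters Missing from the Fitted Model
#'
#' Compares the requested effects with the effects available in the fitted
#' model and stops with an informative error, listing up to three missing
#' parameters, if any are absent.
#'
#' @param effects A character vector of requested effect names.
#' @param model_effects A character vector of effect names present in the fitted model.
#'
#' @return NULL. Called for its side effect of raising an error when parameters are missing.
#'
#' @keywords internal
#' @noRd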
check_missing_parameters <- function(effects, model_effects) {
  # Find missing parameters
  missing_parameters <- 
    effects |> 
    setdiff(model_effects)
  
  # If there are any missing parameters, stop and show an error message
  if (length(missing_parameters) > 0) {
    stop(
      "sccomp says: Some of the parameters present in the data provided were not present when the model was fitted. For example:\n",
      paste(missing_parameters[1:min(3, length(missing_parameters))], collapse = "\n"),
      if (length(missing_parameters) > 3) "\n..."
    )
  }
}


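#' Harmonise Factor Levels Between Two Data Frames
#'
#' For every factor column of `dataframe_reference`, the matching column of
#' `dataframe_query` (if present) is coerced to a factor with the reference's
#' levels. Query values that are absent from the reference levels become `NA`.
#'
#' @param dataframe_query A data frame whose columns are coerced to the reference levels.
#' @param dataframe_reference A data frame supplying the reference factor levels.
#'
#' @return The updated `dataframe_query`.
#'
#' @keywords internal
#' @noRd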
harmonise_factor_levels <- function(dataframe_query, dataframe_reference) {
  # 1. Identify factor columns in the reference
  factor_cols <- names(dataframe_reference)[sapply(dataframe_reference, is.factor)]
  
  # 2. For each reference factor column, if the query has that column, coerce it to factor
  for (col in factor_cols) {
    if (col %in% names(dataframe_query)) {
      # Use the reference's levels
      ref_levels <- levels(dataframe_reference[[col]])
      
      # Force the query to adopt these levels, whether or not it was originally a factor
      dataframe_query[[col]] <- factor(dataframe_query[[col]], levels = ref_levels)
    }
  }
  
  # 3. Return ONLY the updated query
  dataframe_query
}

#' Print Tibble in Red
#'
#' This function captures the console output of printing a tibble,
#' colours it in red and returns the coloured text.
#'
#' @param tbl A data frame or tibble to be printed and coloured in red.
#'
#' @return A character string containing the coloured tibble output.
#'
#' @importFrom crayon red
#' @importFrom utils capture.output
#' @noRd
print_red_tibble <- function(tbl) {
  # Capture the console output of printing the tibble
  example_text <- capture.output(print(tbl))
  
  # Combine all lines into one block and colour it in red
  red(paste(example_text, collapse = "\n"))
}

#' Check if a Sample Column is a Unique Identifier
#'
#' This function checks if the `.sample` column in a wide dataset is truly
#' a unique identifier. If not, it throws an error containing the problematic
#' rows in red text.
#'
#' @param data_wide A data frame or tibble in wide format.
#' @param .sample   An unquoted column name indicating the sample column to check.
#'
#' @return Returns the original `data_wide` if `.sample` is unique. Otherwise,
#'   throws an error showing the problematic rows in red.
#'
#' @importFrom rlang enquo
#' @importFrom rlang quo_name
#' @importFrom dplyr count
#' @importFrom dplyr add_count
#' @importFrom dplyr filter
#' @importFrom dplyr pull
#' @importFrom dplyr select
#' @importFrom glue glue
#' @noRd
check_if_sample_is_a_unique_identifier <- function(data_wide, .sample) {
  .sample <- enquo(.sample)
  
  if (
    data_wide |>
    count(!!.sample) |>
    pull(n) |>
    max() > 1
  ) {
    stop(
      paste(
        glue("sccomp says: .sample column `{quo_name(.sample)}` should be a unique identifier, with a unique combination of factors. For example Sample_A cannot have both treated and untreated conditions in your input"),
        data_wide |>
          add_count(!!.sample, name = "n___") |>
          filter(n___ > 1) |>
          select(-n___) |>
          print_red_tibble(),
        sep = "\n\n"
      )
    )
  } else {
    return(data_wide)
  }
}
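
# Illustrative failure of check_if_sample_is_a_unique_identifier(), kept as a
# comment; the column names are hypothetical:
# tibble::tibble(
#   sample    = c("Sample_A", "Sample_A", "Sample_B"),
#   condition = c("treated",  "untreated", "treated")
# ) |>
#   check_if_sample_is_a_unique_identifier(sample)
# #> Error: sccomp says: .sample column `sample` should be a unique identifier ...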